MLX
Metal-accelerated on Apple Silicon. Kabsch is restricted to 3D inputs (dim == 3).
Float64 operations run on CPU (Apple Silicon GPUs do not support true float64).
mlx
kabsch
kabsch(
P: array, Q: array, weights: array | None = None
) -> tuple[mx.array, mx.array, mx.array]
Computes the optimal rotation and translation to align P to Q.
This MLX implementation supports only 3D inputs (dim == 3) because the determinant correction is hardcoded for 3x3 matrices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| P | array | Source points, shape [..., N, 3]. | required |
| Q | array | Target points, shape [..., N, 3]. | required |
| weights | array \| None | Per-point weights, shape [..., N]. Non-negative, must sum to > 0. When None, all points are weighted equally. | None |
Returns:
| Type | Description |
|---|---|
| (R, t, rmsd) | Rotation [..., 3, 3], translation [..., 3], and RMSD [...]. |
| array | float16/bfloat16 inputs are upcast to float32 internally and downcast on output. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the last dimension of the inputs is not 3 (D != 3). |
Note
R is only stable under global translation when the cross-covariance matrix H = P_c.T @ Q_c is well-conditioned. When the smallest singular value of H is near zero, U and V from the SVD are not unique, and a small perturbation can select a different rotation. Check the singular values of H if rotation stability matters for your use case.
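The computation described above can be sketched in NumPy. This is a reference sketch for a single unbatched [N, 3] pair, not the library's MLX implementation: the name `kabsch_reference` and the extra fourth return value (the singular values of H, exposed for the conditioning check the note recommends) are assumptions made for illustration.

```python
import numpy as np

def kabsch_reference(P, Q, weights=None):
    """Hypothetical NumPy reference: returns (R, t, rmsd, singular_values)."""
    P = np.asarray(P, dtype=np.float64)
    Q = np.asarray(Q, dtype=np.float64)
    if P.shape[-1] != 3 or Q.shape[-1] != 3:
        raise ValueError("expected inputs of shape [N, 3]")
    n = P.shape[0]
    w = np.ones(n) if weights is None else np.asarray(weights, dtype=np.float64)
    wn = w / w.sum()                      # normalized weights for means and RMSD

    p_mean = wn @ P                       # weighted centroids
    q_mean = wn @ Q
    P_c, Q_c = P - p_mean, Q - q_mean

    H = (P_c * w[:, None]).T @ Q_c        # 3x3 cross-covariance, P_c.T @ Q_c when unweighted
    U, S, Vt = np.linalg.svd(H)
    # Determinant correction: flip the last axis if V @ U.T would be a reflection.
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    t = q_mean - R @ p_mean               # so that Q ~ P @ R.T + t

    residual = P @ R.T + t - Q
    rmsd = np.sqrt(np.sum(wn * np.sum(residual**2, axis=1)))
    return R, t, rmsd, S
```

If the smallest entry of the returned singular values is close to zero relative to the largest, the rotation should be treated as potentially unstable, as the note explains.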
kabsch_umeyama
kabsch_umeyama(
P: array, Q: array, weights: array | None = None
) -> tuple[mx.array, mx.array, mx.array, mx.array]
Computes the optimal rotation, translation, and scale to align P to Q (Q ~ c * R @ P + t).
This MLX implementation supports only 3D inputs (dim == 3) because the determinant correction is hardcoded for 3x3 matrices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| P | array | Source points, shape [..., N, 3]. | required |
| Q | array | Target points, shape [..., N, 3]. | required |
| weights | array \| None | Per-point weights, shape [..., N]. Non-negative, must sum to > 0. When None, all points are weighted equally. | None |
Returns:
| Type | Description |
|---|---|
| (R, t, c, rmsd) | Rotation [..., 3, 3], translation [..., 3], scale [...], and RMSD [...]. |
| array | float16/bfloat16 inputs are upcast to float32 and downcast on output. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the last dimension of the inputs is not 3 (D != 3). |
Note
Unlike kabsch, the cross-covariance H is divided by N here. This per-point normalization is required by the Umeyama scale estimator (c = trace(S * D) / var_P) and does not affect the rotation or translation.
R is only stable under global translation and uniform scaling when the cross-covariance matrix H = P_c.T @ Q_c is well-conditioned. When the smallest singular value of H is near zero, U and V from the SVD are not unique, and a small perturbation can select a different rotation. Check the singular values of H if rotation stability matters for your use case.
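The scale estimator mentioned in the note, c = trace(S * D) / var_P, can be illustrated with a small NumPy sketch. It is an unweighted, unbatched reference under the assumption that var_P is the mean squared distance of the source points from their centroid; the name `umeyama_reference` is hypothetical and this is not the library's MLX code.

```python
import numpy as np

def umeyama_reference(P, Q):
    """Hypothetical NumPy reference for one [N, 3] pair: returns (R, t, c, rmsd)."""
    P = np.asarray(P, dtype=np.float64)
    Q = np.asarray(Q, dtype=np.float64)
    n = P.shape[0]
    p_mean, q_mean = P.mean(axis=0), Q.mean(axis=0)
    P_c, Q_c = P - p_mean, Q - q_mean

    H = P_c.T @ Q_c / n                   # cross-covariance divided by N
    var_P = np.sum(P_c**2) / n            # assumed definition of var_P
    U, S, Vt = np.linalg.svd(H)
    # D = diag(1, 1, d) guards against a reflection, as in plain Kabsch.
    D = np.array([1.0, 1.0, np.sign(np.linalg.det(Vt.T @ U.T))])
    R = Vt.T @ np.diag(D) @ U.T
    c = np.sum(S * D) / var_P             # c = trace(S * D) / var_P
    t = q_mean - c * (R @ p_mean)         # so that Q ~ c * P @ R.T + t

    residual = c * (P @ R.T) + t - Q
    rmsd = np.sqrt(np.mean(np.sum(residual**2, axis=1)))
    return R, t, c, rmsd
```

Dividing H and var_P by the same N keeps their ratio, and therefore c, consistent, while the singular vectors that determine R are unchanged by the division.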
horn
horn(
P: array, Q: array, weights: array | None = None
) -> tuple[mx.array, mx.array, mx.array]
Computes the optimal rotation and translation to align P to Q using Horn's quaternion method.
3D inputs only. Uses a gradient-safe eigendecomposition (safe_eigh_fwd) to avoid NaN gradients when point clouds are symmetric or degenerate.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| P | array | Source points as mx.array, shape [..., N, 3]. | required |
| Q | array | Target points as mx.array, shape [..., N, 3]. | required |
| weights | array \| None | Per-point weights, shape [..., N]. Non-negative, must sum to > 0. When None, all points are weighted equally. | None |
Returns:
| Type | Description |
|---|---|
| (R, t, rmsd) | Rotation [..., 3, 3], translation [..., 3], and RMSD [...]. |
| array | float16/bfloat16 inputs are upcast to float32 internally and downcast on output. |
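For orientation, Horn's quaternion method can be sketched in NumPy as follows. This is an unweighted, unbatched illustration with a hypothetical name (`horn_reference`); it uses `np.linalg.eigh` in place of the library's safe_eigh_fwd and makes no attempt to reproduce its gradient-safety behavior.

```python
import numpy as np

def horn_reference(P, Q):
    """Hypothetical NumPy reference for one [N, 3] pair: returns (R, t, rmsd)."""
    P = np.asarray(P, dtype=np.float64)
    Q = np.asarray(Q, dtype=np.float64)
    p_mean, q_mean = P.mean(axis=0), Q.mean(axis=0)
    P_c, Q_c = P - p_mean, Q - q_mean

    S = P_c.T @ Q_c                       # S[i, j] = sum over points of p_i * q_j
    Sxx, Sxy, Sxz = S[0]
    Syx, Syy, Syz = S[1]
    Szx, Szy, Szz = S[2]
    # Horn's symmetric 4x4 matrix; its top eigenvector is the optimal unit quaternion.
    N = np.array([
        [Sxx + Syy + Szz, Syz - Szy,       Szx - Sxz,        Sxy - Syx],
        [Syz - Szy,       Sxx - Syy - Szz, Sxy + Syx,        Szx + Sxz],
        [Szx - Sxz,       Sxy + Syx,       -Sxx + Syy - Szz, Syz + Szy],
        [Sxy - Syx,       Szx + Sxz,       Syz + Szy,        -Sxx - Syy + Szz],
    ])
    _, eigvecs = np.linalg.eigh(N)        # eigenvalues are returned in ascending order
    w, x, y, z = eigvecs[:, -1]           # eigenvector of the largest eigenvalue

    # Convert the unit quaternion (w, x, y, z) to a rotation matrix.
    R = np.array([
        [w*w + x*x - y*y - z*z, 2*(x*y - w*z),         2*(x*z + w*y)],
        [2*(x*y + w*z),         w*w - x*x + y*y - z*z, 2*(y*z - w*x)],
        [2*(x*z - w*y),         2*(y*z + w*x),         w*w - x*x - y*y + z*z],
    ])
    t = q_mean - R @ p_mean               # so that Q ~ P @ R.T + t
    residual = P @ R.T + t - Q
    rmsd = np.sqrt(np.mean(np.sum(residual**2, axis=1)))
    return R, t, rmsd
```

Because a unit quaternion always maps to a proper rotation, this formulation needs no separate determinant correction; the sign ambiguity of the eigenvector does not matter, since q and -q give the same R.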
horn_with_scale
horn_with_scale(
P: array, Q: array, weights: array | None = None
) -> tuple[mx.array, mx.array, mx.array, mx.array]
Computes the optimal rotation, translation, and scale to align P to Q (Q ~ c * R @ P + t).
3D inputs only. Uses a gradient-safe eigendecomposition (safe_eigh_fwd).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| P | array | Source points as mx.array, shape [..., N, 3]. | required |
| Q | array | Target points as mx.array, shape [..., N, 3]. | required |
| weights | array \| None | Per-point weights, shape [..., N]. Non-negative, must sum to > 0. When None, all points are weighted equally. | None |
Returns:
| Type | Description |
|---|---|
| (R, t, c, rmsd) | Rotation [..., 3, 3], translation [..., 3], scale [...], and RMSD [...]. |
| array | float16/bfloat16 inputs are upcast to float32 and downcast on output. |
kabsch_rmsd
kabsch_rmsd(
P: array, Q: array, weights: array | None = None
) -> mx.array
Computes the RMSD after Kabsch alignment. Gradient-safe, so it can be used as a training loss.
kabsch_umeyama_rmsd
kabsch_umeyama_rmsd(
P: array, Q: array, weights: array | None = None
) -> mx.array
Computes the RMSD after Kabsch-Umeyama alignment. Gradient-safe, so it can be used as a training loss.