
JAX

Custom VJP support with safe backward passes. Compatible with jax.jit.

Float64 requires the environment variable JAX_ENABLE_X64=True to be set before JAX is imported.
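JAX also documents a programmatic equivalent of the environment variable, the `jax_enable_x64` config flag. The sketch below assumes it runs at program startup, before any JAX arrays are created. It is not specific to this library.

```python
import jax
import jax.numpy as jnp

# Programmatic equivalent of JAX_ENABLE_X64=True.
# Run this at startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.zeros((4, 3), dtype=jnp.float64)
print(x.dtype)  # float64 (it would silently be float32 with x64 disabled)
```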

jax

kabsch

kabsch(
    P: ndarray, Q: ndarray, weights: ndarray | None = None
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

Computes the optimal rotation and translation to align P to Q using gradient-safe SVD.

Parameters:

P (ndarray, required): Source points, shape [..., N, D].
Q (ndarray, required): Target points, shape [..., N, D].
weights (ndarray | None, default None): Per-point weights, shape [..., N]. Non-negative; must sum to > 0. When None, all points are weighted equally.

Returns:

(R, t, rmsd) (tuple of ndarray): Rotation [..., D, D], translation [..., D], and RMSD [...].

float16/bfloat16 inputs are upcast to float32 internally and downcast on output.
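The following is a minimal sketch of weighted Kabsch alignment with batch dimensions, written to match the shapes documented above. `kabsch_sketch` is a hypothetical name and is not this library's implementation. It uses plain `jnp.linalg.svd`, so unlike the documented `kabsch` its backward pass is not gradient-safe. It also assumes that weights are normalized to sum to 1 and that the RMSD is the weighted root mean squared residual.

```python
import jax
import jax.numpy as jnp


def kabsch_sketch(P, Q, weights=None):
    """Weighted Kabsch fit, Q ~ P @ R.T + t. Shapes: P, Q [..., N, D]."""
    if weights is None:
        weights = jnp.ones(P.shape[:-1], dtype=P.dtype)
    # Normalize weights so they sum to 1 along the point axis.
    w = weights / jnp.sum(weights, axis=-1, keepdims=True)

    # Weighted centroids, shape [..., D], and centered points.
    mu_p = jnp.einsum("...n,...nd->...d", w, P)
    mu_q = jnp.einsum("...n,...nd->...d", w, Q)
    Pc = P - mu_p[..., None, :]
    Qc = Q - mu_q[..., None, :]

    # Weighted cross-covariance H = P_c^T W Q_c, shape [..., D, D].
    H = jnp.einsum("...n,...ni,...nj->...ij", w, Pc, Qc)
    U, S, Vt = jnp.linalg.svd(H, full_matrices=False)
    V = jnp.swapaxes(Vt, -1, -2)
    Ut = jnp.swapaxes(U, -1, -2)

    # Reflection correction: flip the last singular direction if det(V U^T) < 0.
    d = jnp.where(jnp.linalg.det(V @ Ut) < 0, -1.0, 1.0)
    corr = jnp.ones_like(S).at[..., -1].set(d)
    R = (V * corr[..., None, :]) @ Ut  # V diag(corr) U^T

    t = mu_q - jnp.einsum("...ij,...j->...i", R, mu_p)
    resid = jnp.einsum("...ij,...nj->...ni", R, Pc) - Qc
    rmsd = jnp.sqrt(jnp.sum(w * jnp.sum(resid**2, axis=-1), axis=-1))
    return R, t, rmsd
```

The function works under `jax.jit` as written. To avoid NaN gradients when singular values of H coincide, the documented `kabsch` is the safer choice.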

Note

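R is numerically stable under a global translation of the inputs only when the cross-covariance matrix H = P_c.T @ Q_c is well-conditioned. When H is degenerate or nearly so, U and V from the SVD are not unique, and a small perturbation can select a different rotation. This happens, for example, when more than one of its singular values is near zero (as with nearly collinear points) or when its smallest singular values nearly coincide. Check the singular values of H if rotation stability matters for your use case.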
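One way to check the singular values of H is to compute them directly. The sketch below does this for the unweighted case. `cross_covariance_singular_values` is a hypothetical helper, not part of this library. How to threshold the values is left to the caller because it depends on the use case.

```python
import jax.numpy as jnp


def cross_covariance_singular_values(P, Q):
    """Singular values of H = P_c^T Q_c (unweighted), in descending order."""
    Pc = P - jnp.mean(P, axis=-2, keepdims=True)
    Qc = Q - jnp.mean(Q, axis=-2, keepdims=True)
    H = jnp.swapaxes(Pc, -1, -2) @ Qc
    return jnp.linalg.svd(H, compute_uv=False)


# Collinear source points: every point lies on one line, so H has rank 1
# and two of its three singular values are (numerically) zero.
pts = jnp.linspace(0.0, 1.0, 10)[:, None] * jnp.array([1.0, 2.0, 3.0])
s = cross_covariance_singular_values(pts, pts + 1.0)
print(s)
```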

kabsch_umeyama

kabsch_umeyama(
    P: ndarray, Q: ndarray, weights: ndarray | None = None
) -> tuple[
    jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray
]

Computes the optimal rotation, translation, and scale to align P to Q (Q ~ c * R @ P + t).

Parameters:

P (ndarray, required): Source points, shape [..., N, D].
Q (ndarray, required): Target points, shape [..., N, D].
weights (ndarray | None, default None): Per-point weights, shape [..., N]. Non-negative; must sum to > 0. When None, all points are weighted equally.

Returns:

(R, t, c, rmsd) (tuple of ndarray): Rotation [..., D, D], translation [..., D], scale [...], and RMSD [...].

float16/bfloat16 inputs are upcast to float32 and downcast on output.

Note

Unlike kabsch, the cross-covariance H is divided by N here. This per-point normalization is required by the Umeyama scale estimator (c = trace(S * D) / var_P) and does not affect the rotation or translation.
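The estimator below is written out for the unweighted case. It assumes that S is the diagonal matrix of singular values of H, that D is the reflection-correction matrix (the note does not define either), and that P_c, Q_c stack the centered points as rows.

```latex
\begin{aligned}
H &= \tfrac{1}{N} P_c^\top Q_c = U S V^\top, \qquad
D = \operatorname{diag}\bigl(1, \dots, 1, \operatorname{sign}\det(V U^\top)\bigr), \qquad
R = V D U^\top, \\
\sigma_P^2 &= \tfrac{1}{N} \sum_{i=1}^{N} \lVert p_i - \mu_P \rVert^2, \qquad
c = \frac{\operatorname{tr}(S D)}{\sigma_P^2}, \qquad
t = \mu_Q - c\, R\, \mu_P .
\end{aligned}
```

Because var_P is a per-point average, H must be averaged over the same N points for the ratio to give the right scale. An unnormalized H would inflate c by a factor of N. Dividing H by N rescales S but leaves U and V, and therefore R, unchanged.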

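R is numerically stable under global translation and uniform scaling of the inputs only when the cross-covariance matrix H = P_c.T @ Q_c / N is well-conditioned. When H is degenerate or nearly so, U and V from the SVD are not unique, and a small perturbation can select a different rotation. This happens, for example, when more than one of its singular values is near zero (as with nearly collinear points) or when its smallest singular values nearly coincide. Check the singular values of H if rotation stability matters for your use case.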
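Below is a minimal, unweighted sketch of the Umeyama similarity fit described above, for a single [N, D] point set. `umeyama_sketch` is a hypothetical name, not this library's code. Batched inputs can be handled with `jax.vmap(umeyama_sketch)`. Like any plain-SVD version, it does not carry this library's gradient-safe backward pass.

```python
import jax
import jax.numpy as jnp


def umeyama_sketch(P, Q):
    """Unweighted similarity fit for [N, D] inputs: Q ~ c * P @ R.T + t."""
    n = P.shape[0]
    mu_p, mu_q = jnp.mean(P, axis=0), jnp.mean(Q, axis=0)
    Pc, Qc = P - mu_p, Q - mu_q

    H = Pc.T @ Qc / n             # per-point cross-covariance
    var_p = jnp.sum(Pc**2) / n    # per-point variance of the source
    U, S, Vt = jnp.linalg.svd(H, full_matrices=False)

    # Reflection correction on the last singular direction.
    d = jnp.where(jnp.linalg.det(Vt.T @ U.T) < 0, -1.0, 1.0)
    corr = jnp.ones_like(S).at[-1].set(d)
    R = (Vt.T * corr) @ U.T       # V diag(corr) U^T

    c = jnp.sum(S * corr) / var_p  # trace(S D) / var_P
    t = mu_q - c * (R @ mu_p)
    resid = c * (Pc @ R.T) - Qc
    rmsd = jnp.sqrt(jnp.mean(jnp.sum(resid**2, axis=-1)))
    return R, t, c, rmsd
```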

horn

horn(
    P: ndarray, Q: ndarray, weights: ndarray | None = None
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

Computes optimal rotation and translation to align P to Q using Horn's quaternion method.

Strictly 3D only. Uses gradient-safe eigendecomposition (safe_eigh) to avoid NaN gradients when point clouds are symmetric or degenerate.

Parameters:

P (ndarray, required): Source points, shape [..., N, 3].
Q (ndarray, required): Target points, shape [..., N, 3].
weights (ndarray | None, default None): Per-point weights, shape [..., N]. Non-negative; must sum to > 0. When None, all points are weighted equally.

Returns:

(R, t, rmsd) (tuple of ndarray): Rotation [..., 3, 3], translation [..., 3], and RMSD [...].

float16/bfloat16 inputs are upcast to float32 internally and downcast on output.
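The following sketch shows Horn's quaternion method for a single [N, 3] point set. It builds Horn's symmetric 4x4 matrix from M = P_c^T W Q_c and takes the eigenvector of its largest eigenvalue as the rotation quaternion (w, x, y, z). `horn_sketch` and `quat_to_rotmat` are hypothetical names. The sketch substitutes plain `jnp.linalg.eigh` for this library's `safe_eigh`, so its gradients are not protected in symmetric or degenerate cases.

```python
import jax
import jax.numpy as jnp


def quat_to_rotmat(q):
    """Unit quaternion (w, x, y, z) -> 3x3 rotation matrix."""
    w, x, y, z = q[0], q[1], q[2], q[3]
    return jnp.stack([
        jnp.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]),
        jnp.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]),
        jnp.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]),
    ])


def horn_sketch(P, Q, weights=None):
    """Horn's quaternion fit for [N, 3] inputs: Q ~ P @ R.T + t."""
    if weights is None:
        weights = jnp.ones(P.shape[0], dtype=P.dtype)
    w = weights / jnp.sum(weights)
    mu_p, mu_q = w @ P, w @ Q
    Pc, Qc = P - mu_p, Q - mu_q

    # M[a, b] = sum_i w_i * Pc[i, a] * Qc[i, b]
    M = (Pc * w[:, None]).T @ Qc
    Sxx, Sxy, Sxz = M[0, 0], M[0, 1], M[0, 2]
    Syx, Syy, Syz = M[1, 0], M[1, 1], M[1, 2]
    Szx, Szy, Szz = M[2, 0], M[2, 1], M[2, 2]

    # Horn's symmetric 4x4 matrix; q^T N q is the alignment score for unit q.
    N = jnp.stack([
        jnp.stack([Sxx + Syy + Szz, Syz - Szy, Szx - Sxz, Sxy - Syx]),
        jnp.stack([Syz - Szy, Sxx - Syy - Szz, Sxy + Syx, Szx + Sxz]),
        jnp.stack([Szx - Sxz, Sxy + Syx, -Sxx + Syy - Szz, Syz + Szy]),
        jnp.stack([Sxy - Syx, Szx + Sxz, Syz + Szy, -Sxx - Syy + Szz]),
    ])
    # eigh sorts eigenvalues ascending; the optimal quaternion is the last eigenvector.
    _, evecs = jnp.linalg.eigh(N)
    R = quat_to_rotmat(evecs[:, -1])

    t = mu_q - R @ mu_p
    resid = Pc @ R.T - Qc
    rmsd = jnp.sqrt(jnp.sum(w * jnp.sum(resid**2, axis=-1)))
    return R, t, rmsd
```

The sign ambiguity of the eigenvector is harmless, because q and -q map to the same rotation matrix.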

horn_with_scale

horn_with_scale(
    P: ndarray, Q: ndarray, weights: ndarray | None = None
) -> tuple[
    jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray
]

Computes optimal rotation, translation, and scale to align P to Q (Q ~ c * R @ P + t).

Strictly 3D only. Uses gradient-safe eigendecomposition (safe_eigh).

Parameters:

P (ndarray, required): Source points, shape [..., N, 3].
Q (ndarray, required): Target points, shape [..., N, 3].
weights (ndarray | None, default None): Per-point weights, shape [..., N]. Non-negative; must sum to > 0. When None, all points are weighted equally.

Returns:

(R, t, c, rmsd) (tuple of ndarray): Rotation [..., 3, 3], translation [..., 3], scale [...], and RMSD [...].

float16/bfloat16 inputs are upcast to float32 and downcast on output.
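This page does not state which scale estimator horn_with_scale uses. One common choice, assumed here, is the least-squares scale for a fixed rotation: c = sum_i w_i q_c,i · (R p_c,i) / sum_i w_i ||p_c,i||². This agrees with the Umeyama scale used by kabsch_umeyama. Horn's paper also discusses a symmetric variant. `scale_given_rotation` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def scale_given_rotation(P, Q, R, weights=None):
    """Least-squares c minimizing sum_i w_i ||c R p_i + t - q_i||^2 for a fixed R."""
    if weights is None:
        weights = jnp.ones(P.shape[0], dtype=P.dtype)
    w = weights / jnp.sum(weights)
    Pc = P - w @ P
    Qc = Q - w @ Q
    num = jnp.sum(w * jnp.sum(Qc * (Pc @ R.T), axis=-1))  # sum_i w_i q_c . (R p_c)
    den = jnp.sum(w * jnp.sum(Pc**2, axis=-1))            # sum_i w_i ||p_c||^2
    return num / den
```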

kabsch_rmsd

kabsch_rmsd(
    P: ndarray, Q: ndarray, weights: ndarray | None = None
) -> jnp.ndarray

Computes the RMSD after optimal Kabsch (rotation and translation) alignment of P to Q. Gradient-safe, so it can be used directly as a training loss.
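The sketch below shows the training-loss pattern with `jax.value_and_grad` and `jax.jit`. It uses a self-contained, unweighted stand-in, `kabsch_rmsd_sketch` (a hypothetical name), that runs gradient descent on P toward Q. The stand-in uses plain `jnp.linalg.svd`, whose gradient can be NaN or infinite when singular values of H coincide; the documented `kabsch_rmsd` is described as gradient-safe in those cases.

```python
import jax
import jax.numpy as jnp


def kabsch_rmsd_sketch(P, Q):
    """Unweighted RMSD after optimal rigid alignment of P onto Q, [N, D] inputs."""
    Pc = P - jnp.mean(P, axis=0)
    Qc = Q - jnp.mean(Q, axis=0)
    U, _, Vt = jnp.linalg.svd(Pc.T @ Qc, full_matrices=False)
    d = jnp.where(jnp.linalg.det(Vt.T @ U.T) < 0, -1.0, 1.0)
    corr = jnp.ones(P.shape[-1], dtype=P.dtype).at[-1].set(d)
    R = (Vt.T * corr) @ U.T
    resid = Pc @ R.T - Qc
    return jnp.sqrt(jnp.mean(jnp.sum(resid**2, axis=-1)))


@jax.jit
def sgd_step(P, Q):
    # One plain gradient-descent step on the aligned RMSD.
    loss, grad = jax.value_and_grad(kabsch_rmsd_sketch)(P, Q)
    return P - 0.05 * grad, loss


k1, k2 = jax.random.split(jax.random.PRNGKey(3))
P = jax.random.normal(k1, (8, 3))
Q = jax.random.normal(k2, (8, 3))
losses = []
for _ in range(50):
    P, loss = sgd_step(P, Q)
    losses.append(float(loss))
print(losses[0], losses[-1])  # the aligned RMSD should shrink
```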

kabsch_umeyama_rmsd

kabsch_umeyama_rmsd(
    P: ndarray, Q: ndarray, weights: ndarray | None = None
) -> jnp.ndarray

Computes the RMSD after optimal Kabsch-Umeyama (rotation, translation, and scale) alignment of P to Q. Gradient-safe, so it can be used directly as a training loss.
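To illustrate how the two RMSD losses differ, the sketch below compares a rigid fit with a similarity fit on a uniformly scaled copy of the same points. It assumes that kabsch_umeyama_rmsd corresponds to the residual of the fit Q ~ c * R @ P + t described for kabsch_umeyama. `aligned_rmsd` is a hypothetical, unweighted stand-in, not this library's function.

```python
import jax
import jax.numpy as jnp


def aligned_rmsd(P, Q, with_scale):
    """RMSD after rigid (with_scale=False) or similarity (with_scale=True) alignment."""
    n = P.shape[0]
    Pc = P - jnp.mean(P, axis=0)
    Qc = Q - jnp.mean(Q, axis=0)
    U, S, Vt = jnp.linalg.svd(Pc.T @ Qc / n, full_matrices=False)
    d = jnp.where(jnp.linalg.det(Vt.T @ U.T) < 0, -1.0, 1.0)
    corr = jnp.ones_like(S).at[-1].set(d)
    R = (Vt.T * corr) @ U.T
    # Umeyama scale trace(S D) / var_P for the similarity fit, 1 for the rigid fit.
    c = jnp.sum(S * corr) / (jnp.sum(Pc**2) / n) if with_scale else 1.0
    resid = c * (Pc @ R.T) - Qc
    return jnp.sqrt(jnp.mean(jnp.sum(resid**2, axis=-1)))


P = jax.random.normal(jax.random.PRNGKey(4), (10, 3))
Q = 3.0 * P  # a uniformly scaled copy of P
rigid = aligned_rmsd(P, Q, with_scale=False)    # clearly > 0: a rotation cannot absorb scale
similar = aligned_rmsd(P, Q, with_scale=True)   # ~0: the scale is fitted
print(rigid, similar)
```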