Getting Started¶
Installation¶
pip¶
pip install kabsch-horn-cookbook
uv¶
uv add kabsch-horn-cookbook
Copy-the-folder¶
The code has no runtime dependencies beyond the framework itself. Copy the framework folder you need from src/kabsch_horn/<framework>/ directly into your project:
src/kabsch_horn/
├── numpy/
├── pytorch/
├── jax/
├── tensorflow/
└── mlx/
Each folder contains two files: kabsch_svd_nd.py for SVD-based alignment and horn_quat_3d.py for quaternion-based alignment. Copy one file, both, or the whole folder.
Quickstart¶
Every framework exports the same API, apart from the exceptions listed under Framework notes. Replace pytorch with jax, tensorflow, mlx, or numpy as needed.
Kabsch alignment (N-dimensional)¶
import torch
from kabsch_horn import pytorch as kh
# Batched N-dimensional points
P = torch.randn(10, 100, 64)
Q = torch.randn(10, 100, 64)
# R: (10, 64, 64) | t: (10, 64) | rmsd: (10,)
R, t, rmsd = kh.kabsch(P, Q)
# Umeyama variant (with global scale)
# R: (10, 64, 64) | t: (10, 64) | c: (10,) | rmsd: (10,)
R, t, c, rmsd = kh.kabsch_umeyama(P, Q)
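To make the returned values concrete, here is a minimal NumPy sketch of the textbook SVD-based Kabsch procedure for a single, unbatched pair of point sets. It is not the library's implementation. The name kabsch_reference is hypothetical, and the convention that R and t map P onto Q (Q ≈ P @ R.T + t) is an assumption, so check the package's docstrings before relying on it.

```python
import numpy as np

def kabsch_reference(P, Q):
    """Illustrative Kabsch alignment for one pair of (N, D) point sets.

    Assumed convention (not confirmed by the library docs): Q ≈ P @ R.T + t.
    """
    p_mean = P.mean(axis=0)
    q_mean = Q.mean(axis=0)
    Pc = P - p_mean
    Qc = Q - q_mean

    # Cross-covariance of the centered point sets, shape (D, D).
    H = Pc.T @ Qc
    U, _, Vt = np.linalg.svd(H)

    # Flip the last singular direction when needed so that det(R) = +1
    # (a proper rotation rather than a reflection).
    d = -1.0 if np.linalg.det(Vt.T @ U.T) < 0 else 1.0
    D = np.eye(P.shape[1])
    D[-1, -1] = d
    R = Vt.T @ D @ U.T

    # Translation that carries the rotated centroid of P onto the centroid of Q.
    t = q_mean - R @ p_mean

    aligned = P @ R.T + t
    rmsd = np.sqrt(np.mean(np.sum((aligned - Q) ** 2, axis=1)))
    return R, t, rmsd
```

Under the assumed convention, applying the result is `P @ R.T + t`; a batched version would perform the same steps over the leading dimensions.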
Horn alignment (3D quaternion)¶
P_3d = torch.randn(10, 100, 3)
Q_3d = torch.randn(10, 100, 3)
# R: (10, 3, 3) | t: (10, 3) | rmsd: (10,)
R, t, rmsd = kh.horn(P_3d, Q_3d)
# Horn with scale
R, t, c, rmsd = kh.horn_with_scale(P_3d, Q_3d)
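For comparison with the SVD route, the following is a minimal NumPy sketch of Horn's closed-form quaternion method for one unbatched pair of 3D point sets. It illustrates the textbook algorithm rather than the library's code. The name horn_reference is hypothetical, and the convention Q ≈ P @ R.T + t is an assumption.

```python
import numpy as np

def horn_reference(P, Q):
    """Illustrative Horn quaternion alignment for one pair of (N, 3) point sets.

    Assumed convention (not confirmed by the library docs): Q ≈ P @ R.T + t.
    """
    p_mean = P.mean(axis=0)
    q_mean = Q.mean(axis=0)

    # S[a, b] = sum_i p_a * q_b over the centered points.
    S = (P - p_mean).T @ (Q - q_mean)
    Sxx, Sxy, Sxz = S[0]
    Syx, Syy, Syz = S[1]
    Szx, Szy, Szz = S[2]

    # Horn's symmetric 4x4 matrix; its top eigenvector is the optimal quaternion.
    N = np.array([
        [Sxx + Syy + Szz, Syz - Szy, Szx - Sxz, Sxy - Syx],
        [Syz - Szy, Sxx - Syy - Szz, Sxy + Syx, Szx + Sxz],
        [Szx - Sxz, Sxy + Syx, -Sxx + Syy - Szz, Syz + Szy],
        [Sxy - Syx, Szx + Sxz, Syz + Szy, -Sxx - Syy + Szz],
    ])

    # eigh returns eigenvalues in ascending order, so the last column
    # is the unit eigenvector of the largest eigenvalue.
    _, vecs = np.linalg.eigh(N)
    w, x, y, z = vecs[:, -1]

    # Unit quaternion (w, x, y, z) to rotation matrix.
    R = np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])

    t = q_mean - R @ p_mean
    aligned = P @ R.T + t
    rmsd = np.sqrt(np.mean(np.sum((aligned - Q) ** 2, axis=1)))
    return R, t, rmsd
```

The sign of the eigenvector is arbitrary, but q and -q encode the same rotation, so R is unaffected.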
Per-point weights¶
All alignment functions accept an optional weights parameter of shape [..., N], where N is the number of points and the leading dimensions match those of P and Q. Weights must be non-negative and sum to a positive value.
# Confidence scores, B-factors, attention masks, etc.
weights = torch.rand(10, 100) # shape: [..., N]
R, t, rmsd = kh.kabsch(P, Q, weights=weights)
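As a rough illustration of how per-point weights usually enter an alignment, the NumPy sketch below computes a weighted centroid and a weighted RMSD for inputs of shape [..., N, D] and weights of shape [..., N]. The helper names are hypothetical, and normalizing the weights to sum to one is an assumption about the convention, not a statement of what the library does internally.

```python
import numpy as np

def weighted_centroid(X, weights):
    """X: [..., N, D], weights: [..., N] -> centroid of shape [..., D]."""
    w = np.asarray(weights, dtype=X.dtype)
    # Mirror the documented constraints: non-negative, positive sum.
    if np.any(w < 0) or np.any(w.sum(axis=-1) <= 0):
        raise ValueError("weights must be non-negative with a positive sum")
    w = w / w.sum(axis=-1, keepdims=True)
    return np.einsum("...n,...nd->...d", w, X)

def weighted_rmsd(P_aligned, Q, weights):
    """Weighted RMSD between already-aligned points, shape [...]."""
    w = np.asarray(weights, dtype=Q.dtype)
    w = w / w.sum(axis=-1, keepdims=True)
    sq_dist = np.sum((P_aligned - Q) ** 2, axis=-1)  # [..., N]
    return np.sqrt(np.sum(w * sq_dist, axis=-1))
```

With this normalization, a zero weight removes a point from both the centroid and the error, which is why a 0/1 mask can be passed as weights.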
RMSD loss for training¶
Autodiff frameworks export single-call RMSD loss functions with stable gradients:
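# P must be part of the autograd graph (e.g. a model output) for backward() to work
P = torch.randn(10, 100, 64, requires_grad=True)
loss = kh.kabsch_rmsd(P, Q)
loss.mean().backward()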
Framework notes¶
- NumPy -- Forward-pass only. No autograd wrappers or RMSD loss functions.
- MLX -- Kabsch is restricted to 3D inputs (dim == 3). Horn works for 3D as in all frameworks.
- JAX -- Float64 requires the environment variable JAX_ENABLE_X64=True to be set before importing JAX.
Compiler support¶
The PyTorch functions work with torch.compile, and the JAX functions work with jax.jit:
compiled_kabsch = torch.compile(kh.kabsch)
R, t, rmsd = compiled_kabsch(P, Q)
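import jax
from kabsch_horn import jax as kh
# JAX arrays with the same shapes as the PyTorch example
key_p, key_q = jax.random.split(jax.random.PRNGKey(0))
P = jax.random.normal(key_p, (10, 100, 64))
Q = jax.random.normal(key_q, (10, 100, 64))
jitted_kabsch = jax.jit(kh.kabsch)
R, t, rmsd = jitted_kabsch(P, Q)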