RoMa: A lightweight library to deal with 3D rotations in PyTorch.
RoMa (which stands for Rotation Manipulation) provides differentiable mappings between 3D rotation representations, mappings from Euclidean to rotation space, and various utilities related to rotations. It is implemented in PyTorch and aims to be an easy-to-use and reasonably efficient toolbox for Machine Learning and gradient-based optimization.
Installation
The easiest way to install RoMa is to use pip:
pip install roma
We also recommend installing torch-batch-svd to achieve a significant speed-up with procrustes() on a CUDA GPU (see section Why a new library?).
Alternatively one can install the latest version of RoMa directly from the source repository:
pip install git+https://github.com/naver/roma
or include the source repository (https://github.com/naver/roma) as a Git submodule.
Main features
Supported rotation representations
- Rotation vector (rotvec)
Encoded using a …x3 tensor.
The 3D vector angle * axis represents a rotation of angle "angle" (expressed in radians) around the unit 3D vector "axis".
- Unit quaternion (unitquat)
Encoded as …x4 tensor.
Note
We use XYZW quaternion convention, i.e. components of quaternion \(x i + y j + z k + w\) are represented by the 4D vector \((x,y,z,w)\).
We assume unit quaternions to be of unit length, and do not perform implicit normalization.
- Rotation matrix (rotmat)
Encoded as a …xDxD tensor (D=3 for 3D rotations).
We use column-vector convention, i.e. \(R X\) is the transformation of a Dx1 vector \(X\) by a rotation matrix \(R\).
- Euler and Tait-Bryan angles are NOT currently supported.
This is because of the many different existing conventions, and because of the limited interest of such parameterization for numerical applications.
Mappings between rotation representations
RoMa provides functions to convert between rotation representations.
Example mapping a batch of rotation vectors into corresponding unit quaternions:
import torch, roma
batch_shape = (3, 2)
rotvec = torch.randn(batch_shape + (3,))
q = roma.rotvec_to_unitquat(rotvec)
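To make the conventions concrete, here is a minimal plain-Python sketch of the textbook formula behind such a conversion, for a single rotation vector. It is not RoMa's implementation: the function name `rotvec_to_unitquat_sketch` and the small-angle threshold are assumptions made for illustration, and no batching or autograd is involved.

```python
import math

def rotvec_to_unitquat_sketch(rotvec):
    """Map a rotation vector (angle * axis) to a unit quaternion (x, y, z, w)."""
    x, y, z = rotvec
    angle = math.sqrt(x * x + y * y + z * z)
    if angle < 1e-8:
        # Small-angle fallback (assumed threshold): sin(angle / 2) / angle tends to 1/2.
        scale = 0.5
    else:
        scale = math.sin(0.5 * angle) / angle
    # Vector part first, scalar part last (XYZW convention).
    return (scale * x, scale * y, scale * z, math.cos(0.5 * angle))

# Rotation of pi/2 radians around the z axis.
q_sketch = rotvec_to_unitquat_sketch((0.0, 0.0, math.pi / 2))
print(q_sketch)  # Vector part along z, scalar part last.
```

The vector part is the rotation axis scaled by sin(angle / 2), and the scalar part cos(angle / 2) comes last, which is what the XYZW convention prescribes.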
Mappings from Euclidean to 3D rotation space
Mapping an arbitrary tensor to a valid rotation can be useful e.g. for Machine Learning applications. While rotation vectors or Euler angles can be used for this purpose, they suffer from various shortcomings, and we therefore provide the following alternative mappings:
special_gramschmidt()
Mapping from a 3x2 tensor to 3x3 rotation matrix, using special Gram-Schmidt orthonormalization (6D representation, popularized by Zhou et al.).
special_procrustes()
Mapping from an arbitrary nxn matrix to an nxn rotation matrix, using special orthogonal Procrustes orthonormalization.
symmatrixvec_to_unitquat()
Mapping from a 10D vector to an antipodal pair of quaternions through eigenvector decomposition of a 4x4 symmetric matrix, proposed by Peretroukhin et al.
For general purpose applications, we recommend the use of special_procrustes(), which projects an arbitrary square matrix onto the closest matrix of the rotation space, considering the Frobenius norm. Please refer to the paper listed in the References section for more insights.
Example mapping random 3x3 matrices to valid rotation matrices:
import torch, roma
batch_shape = (5,)
M = torch.randn(batch_shape + (3,3))
R = roma.special_procrustes(M)
assert roma.is_rotation_matrix(R, epsilon=1e-5)
Support for an arbitrary number of batch dimensions
For convenience, functions accept an arbitrary number of batch dimensions:
import torch, roma
print(roma.rotvec_to_rotmat(torch.randn(3)).shape) # -> torch.Size([3, 3])
print(roma.rotvec_to_rotmat(torch.randn(5, 3)).shape) # -> torch.Size([5, 3, 3])
print(roma.rotvec_to_rotmat(torch.randn(2, 5, 3)).shape) # -> torch.Size([2, 5, 3, 3])
Quaternion operations
import torch, roma
q = torch.randn(4) # Random unnormalized quaternion
qconv = roma.quat_conjugation(q) # Quaternion conjugation
qinv = roma.quat_inverse(q) # Quaternion inverse
print(roma.quat_product(q, qinv)) # -> [0,0,0,1] identity quaternion
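The XYZW storage order matters when writing quaternion algebra by hand. The following plain-Python sketch spells out the Hamilton product, conjugation and inverse for quaternions stored as (x, y, z, w). The `*_sketch` names are hypothetical helpers written for this illustration, not RoMa functions.

```python
def quat_product_sketch(p, q):
    """Hamilton product p * q of two quaternions stored as (x, y, z, w)."""
    px, py, pz, pw = p
    qx, qy, qz, qw = q
    return (
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
        pw * qw - px * qx - py * qy - pz * qz,
    )

def quat_conjugation_sketch(q):
    """Negate the vector part, keep the scalar part."""
    x, y, z, w = q
    return (-x, -y, -z, w)

def quat_inverse_sketch(q):
    """Conjugate divided by the squared norm (undefined for the null quaternion)."""
    squared_norm = sum(c * c for c in q)
    return tuple(c / squared_norm for c in quat_conjugation_sketch(q))

q_ex = (1.0, 2.0, 3.0, 4.0)  # Arbitrary unnormalized quaternion.
identity_ex = quat_product_sketch(q_ex, quat_inverse_sketch(q_ex))
print(identity_ex)  # Close to (0, 0, 0, 1), up to rounding.
```

For a unit quaternion the squared norm is 1, so the inverse reduces to the conjugate.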
Rotation composition and inverse
Example using rotation vector representation:
import torch, roma
rotvecs = [roma.random_rotvec() for _ in range(3)] # Random rotation vectors
r = roma.rotvec_composition(rotvecs) # Composition of an arbitrary number of rotations
rinv = roma.rotvec_inverse(r) # Rotation vector corresponding to the inverse rotation
Rotation metrics
RoMa implements some usual similarity measures over the 3D rotation space:
import torch, roma
R1, R2 = roma.random_rotmat(size=5), roma.random_rotmat(size=5)
theta = roma.rotmat_geodesic_distance(R1, R2) # In radians
cos_theta = roma.rotmat_cosine_angle(R1.transpose(-2, -1) @ R2)
Weighted rotation averaging
special_procrustes() can be used to easily average rotations:
import torch, roma
n = 5
R_i = roma.random_rotmat(n) # Batch of n 3x3 rotation matrices to average
w_i = torch.rand(n) # Weight of each matrix, between 0 and 1
M = torch.sum(w_i[:,None, None] * R_i, dim=0) # 3x3 matrix
R = roma.special_procrustes(M) # weighted average.
To be precise, it consists in the Fréchet mean considering the chordal distance.
Note
The same average could be performed using quaternion representation and symmatrix mapping (slower batched implementation on GPU).
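The link between the chordal Fréchet mean and the Procrustes projection can be checked with a short derivation. This is a sketch that assumes non-negative weights \(w_i\), writes \(M = \sum_i w_i R_i\), and uses the fact that \(\| R \|_F^2 = 3\) for any 3x3 rotation matrix:

```latex
\begin{aligned}
\sum_i w_i \| R - R_i \|_F^2
  &= \sum_i w_i \left( \| R \|_F^2 - 2\,\mathrm{Trace}(R^\top R_i) + \| R_i \|_F^2 \right) \\
  &= 6 \sum_i w_i - 2\,\mathrm{Trace}(R^\top M), \\
\| M - R \|_F^2
  &= \| M \|_F^2 + 3 - 2\,\mathrm{Trace}(R^\top M).
\end{aligned}
```

Both expressions depend on \(R\) only through \(-2\,\mathrm{Trace}(R^\top M)\), so the rotation minimizing the weighted sum of squared chordal distances is also the rotation closest to \(M\) in Frobenius norm.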
Rigid registration
rigid_points_registration() and rigid_vectors_registration() make it possible to align ordered sets of points/vectors:
import torch, roma
R_gt = roma.random_rotmat()
t_gt = torch.randn(1, 3)
src = torch.randn(100, 3) # source points / vectors
target = src @ R_gt.T # target vectors
R_predicted = roma.rigid_vectors_registration(src, target)
print(f"R_gt\n{R_gt}")
print(f"R_predicted\n{R_predicted}")
target = src @ R_gt.T + t_gt # target points
R_predicted, t_predicted = roma.rigid_points_registration(src, target)
print(f"R_gt\n{R_gt}")
print(f"R_predicted\n{R_predicted}")
print(f"t_gt\n{t_gt}")
print(f"t_predicted\n{t_predicted}")
Spherical linear interpolation (SLERP)
SLERP between batches of unit quaternions:
import torch, roma
q0, q1 = roma.random_unitquat(size=10), roma.random_unitquat(size=10)
steps = torch.linspace(0, 1.0, 5)
q_interpolated = roma.utils.unitquat_slerp(q0, q1, steps)
idx = 1 # Print interpolations for an arbitrary element of the batch
print('q0:\n', q0[idx])
print('q1:\n', q1[idx])
print('q_interpolated:\n', q_interpolated[:,idx])
SLERP between rotation vectors (shortest path interpolation):
import torch, roma
steps = torch.linspace(0, 1.0, 5)
rotvec0, rotvec1 = torch.randn(3), torch.randn(3)
rotvec_interpolated = roma.rotvec_slerp(rotvec0, rotvec1, steps)
Why a new library?
- We could not find a PyTorch library satisfying our needs, so we built our own.
We wanted a reliable, easy-to-use and efficient toolbox to deal with rotation representations in PyTorch. While Kornia provides some utility functions to deal with 3D rotations, it included several major bugs at the time of writing (early 2021) (see e.g. https://github.com/kornia/kornia/issues/723 or https://github.com/kornia/kornia/issues/317).
- Care for numerical precision
RoMa is implemented with numerical precision in mind, e.g. with a special handling of small angle rotation vectors or through the choice of appropriate algorithms.
As an example, consider a function that takes as input an angle \(\theta\), produces a rotation matrix \(R_z(\theta)\) of angle \(\theta\) and estimates its geodesic distance with respect to the identity matrix, using 32-bit floating point arithmetic. We observe that rotmat_geodesic_distance() is much more precise for this task than another implementation often found in academic code: rotmat_geodesic_distance_naive(). The backward pass through rotmat_geodesic_distance_naive() leads to unstable gradient estimations and produces Not-a-Number values for small angles, whereas rotmat_geodesic_distance() is well-behaved, and returns Not-a-Number only for a 0.0 angle, where the gradient is mathematically undefined.
- Computation efficiency
RoMa favors code clarity, but aims to be reasonably efficient. In particular, for Procrustes orthonormalization it can use on NVIDIA GPUs a batched SVD decomposition that provides orders of magnitude speed-ups for large batch sizes compared to vanilla torch.svd() (tested with random 3x3 matrices, PyTorch 1.7, an NVIDIA Tesla T4 GPU and CUDA 11.0).
- Syntactic sugar
RoMa aims to be easy-to-use with a simple syntax, and supports an arbitrary number of batch dimensions to let its users focus on their applications.
API Documentation
Mappings
Various mappings between different rotation representations.
- procrustes(M, force_rotation=False, regularization=0.0, gradient_eps=1e-05)
Returns the orthonormal matrix \(R\) minimizing Frobenius norm \(\| M - R \|_F\).
- Parameters
M (...xNxN tensor) – batch of square matrices.
force_rotation (bool) – if True, forces the output to be a rotation matrix.
regularization (float >= 0) – weight of a regularization term added to the gradient. Using this regularization is equivalent to adding a term \(regularization * \| M - R \|_F^2\) to the training loss function.
gradient_eps (float > 0) – small value used to enforce numerical stability during backpropagation.
- Returns
batch of orthonormal matrices (…xNxN tensor).
- procrustes_naive(M, force_rotation: bool = False)
Implementation of procrustes() relying on the default backward pass of autograd and SVD decomposition. Could be slightly less stable than procrustes().
- rotmat_to_rotvec(R)
Converts rotation matrix to rotation vector representation.
- Parameters
R (...x3x3 tensor) – batch of rotation matrices.
- Returns
batch of rotation vectors (…x3 tensor).
- rotmat_to_unitquat(R)
Converts rotation matrix to unit quaternion representation.
- Parameters
R (...x3x3 tensor) – batch of rotation matrices.
- Returns
batch of unit quaternions (…x4 tensor, XYZW convention).
- rotvec_to_rotmat(rotvec: Tensor, epsilon=1e-06) -> Tensor
Converts rotation vector to rotation matrix representation. Conversion uses Rodrigues formula in general, and a first order approximation for small angles.
- Parameters
rotvec (...x3 tensor) – batch of rotation vectors.
epsilon (float) – small angle threshold.
- Returns
batch of rotation matrices (…x3x3 tensor).
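As an illustration of the Rodrigues formula and of the small-angle fallback mentioned above, here is a minimal plain-Python sketch for a single rotation vector. It is not RoMa's code: the helper names are hypothetical, matrices are nested lists, and the `epsilon` default only mirrors the documented signature.

```python
import math

def skew(v):
    """Cross-product matrix [v]_x, such that skew(v) applied to u equals v x u."""
    x, y, z = v
    return [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]

def matmul3(A, B):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def rotvec_to_rotmat_sketch(rotvec, epsilon=1e-6):
    """Rodrigues formula, with a first-order fallback below `epsilon`."""
    angle = math.sqrt(sum(c * c for c in rotvec))
    K = skew(rotvec)  # Equals angle * skew(axis).
    if angle < epsilon:
        a, b = 1.0, 0.0  # First-order approximation: R ~ I + K.
    else:
        a = math.sin(angle) / angle
        b = (1.0 - math.cos(angle)) / (angle * angle)
    K2 = matmul3(K, K)
    return [[(1.0 if i == j else 0.0) + a * K[i][j] + b * K2[i][j]
             for j in range(3)] for i in range(3)]

# Rotation of pi/2 around z, applied to the column vector X = (1, 0, 0).
R_sketch = rotvec_to_rotmat_sketch((0.0, 0.0, math.pi / 2))
X = (1.0, 0.0, 0.0)
RX = [sum(R_sketch[i][k] * X[k] for k in range(3)) for i in range(3)]
print(RX)  # Close to [0, 1, 0], up to rounding.
```

The last lines also show the column-vector convention: the rotated vector is obtained as the product R X, with R on the left.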
- rotvec_to_unitquat(rotvec)
Converts rotation vector into unit quaternion representation.
- Parameters
rotvec (...x3 tensor) – batch of rotation vectors.
- Returns
batch of unit quaternions (…x4 tensor, XYZW convention).
- special_gramschmidt(M, epsilon=0)
Returns the 3x3 rotation matrix obtained by Gram-Schmidt orthonormalization of two 3D input vectors (first two columns of input matrix M).
- Parameters
M (...x3xN tensor) – batch of 3xN matrices, with N >= 2. Only the first two columns of the matrices are used for orthonormalization.
epsilon (float >= 0) – optional clamping value to avoid returning Not-a-Number values in case of ill-defined input.
- Returns
batch of rotation matrices (…x3x3 tensor).
Warning
In case of ill-defined input (colinear input column vectors), the output will not be a rotation matrix.
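The construction can be sketched in a few lines of plain Python for a single 3x2 input. This is an illustrative reading of special Gram-Schmidt orthonormalization, not RoMa's implementation: helper names are hypothetical, there is no `epsilon` clamping, and colinear columns would raise a division by zero here.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]

def normalized(u):
    norm = math.sqrt(dot(u, u))
    return [a / norm for a in u]

def special_gramschmidt_sketch(M):
    """M: 3x2 matrix given as three rows [a_i, b_i]. Returns a 3x3 rotation matrix."""
    a = [row[0] for row in M]
    b = [row[1] for row in M]
    x = normalized(a)                      # First column: normalized a.
    d = dot(x, b)
    y = normalized([bi - d * xi for bi, xi in zip(b, x)])  # Remove the part of b along x.
    z = cross(x, y)                        # Third column makes the frame right-handed.
    return [[x[i], y[i], z[i]] for i in range(3)]

M_in = [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
R_gs = special_gramschmidt_sketch(M_in)
```

Taking the third column as a cross product, rather than orthonormalizing a third input vector, is what guarantees a determinant of +1.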
- special_procrustes(M, regularization=0.0, gradient_eps=1e-05)
Returns the rotation matrix \(R\) minimizing Frobenius norm \(\| M - R \|_F\).
- Parameters
M (...xNxN tensor) – batch of square matrices.
regularization (float >= 0) – weight of a regularization term added to the gradient. Using this regularization is equivalent to adding a term \(regularization * \| M - R \|_F^2\) to the training loss function.
gradient_eps (float > 0) – small value used to enforce numerical stability during backpropagation.
- Returns
batch of rotation matrices (…xNxN tensor).
- special_procrustes_naive(M)
Implementation of special_procrustes() relying on the default backward pass of autograd and SVD decomposition. Could be slightly less stable than special_procrustes().
- symmatrix_to_projective_point(A)
Converts a DxD symmetric matrix A into a projective point represented by a unit vector \(q\) minimizing \(q^T A q\).
- Parameters
A (...xDxD tensor) – batch of symmetric matrices. Only the lower triangular part is considered.
- Returns
batch of unit vectors \(q\) (…xD tensor).
- Reference:
V. Peretroukhin, M. Giamou, D. M. Rosen, W. N. Greene, N. Roy, and J. Kelly, “A Smooth Representation of Belief over SO(3) for Deep Rotation Learning with Uncertainty,” 2020, doi: 10.15607/RSS.2020.XVI.007.
Warning
This mapping is unstable when the smallest eigenvalue of A has a multiplicity strictly greater than 1.
The eigenvalue decomposition may fail, in particular when using single precision numbers.
Current implementation is rather slow due to the implementation of torch.symeig. The CuSolver library provides a faster eigenvalue decomposition alternative, but results were found to be unreliable.
- symmatrixvec_to_unitquat(x)
Converts a 10D vector into a unit quaternion representation. Based on symmatrix_to_projective_point().
- Parameters
x (...x10 tensor) – batch of 10D vectors.
- Returns
batch of unit quaternions (…x4 tensor, XYZW convention).
- Reference:
V. Peretroukhin, M. Giamou, D. M. Rosen, W. N. Greene, N. Roy, and J. Kelly, “A Smooth Representation of Belief over SO(3) for Deep Rotation Learning with Uncertainty,” 2020, doi: 10.15607/RSS.2020.XVI.007.
- unitquat_to_rotmat(quat)
Converts unit quaternion into rotation matrix representation.
- Parameters
quat (...x4 tensor, XYZW convention) – batch of unit quaternions. No normalization is applied before computation.
- Returns
batch of rotation matrices (…x3x3 tensor).
- unitquat_to_rotvec(quat)
Converts unit quaternion into rotation vector representation.
- Parameters
quat (...x4 tensor, XYZW convention) – batch of unit quaternions. No normalization is applied before computation.
- Returns
batch of rotation vectors (…x3 tensor).
Utils
Various utility functions related to rotation representations.
- is_orthonormal_matrix(R, epsilon=1e-07)
Test if matrices are orthonormal.
- Parameters
R (...xDxD tensor) – batch of square matrices.
epsilon – tolerance threshold.
- Returns
boolean tensor (shape …).
- is_rotation_matrix(R, epsilon=1e-07)
Test if matrices are rotation matrices.
- Parameters
R (...xDxD tensor) – batch of square matrices.
epsilon – tolerance threshold.
- Returns
boolean tensor (shape …).
- quat_composition(sequence, normalize=False)
Returns the product of a sequence of quaternions.
- Parameters
sequence (sequence of ...x4 tensors, XYZW convention) – sequence of batches of quaternions.
normalize (bool) – if True, normalize the returned quaternion.
- Returns
batch of quaternions (…x4 tensor, XYZW convention).
- quat_conjugation(quat)
Returns the conjugation of input batch of quaternions.
- Parameters
quat (...x4 tensor, XYZW convention) – batch of quaternions.
- Returns
batch of quaternions (…x4 tensor, XYZW convention).
Note
Conjugation of a unit quaternion is equal to its inverse.
- quat_inverse(quat)
Returns the inverse of a batch of quaternions.
- Parameters
quat (...x4 tensor, XYZW convention) – batch of quaternions.
- Returns
batch of quaternions (…x4 tensor, XYZW convention).
Note
Inverse of null quaternion is undefined.
For unit quaternions, consider using conjugation instead.
- quat_product(p, q)
Returns the product of two quaternions.
- Parameters
p (...x4 tensor, XYZW convention) – batch of quaternions.
q (...x4 tensor, XYZW convention) – batch of quaternions.
- Returns
batch of quaternions (…x4 tensor, XYZW convention).
- random_rotmat(size=(), dtype=torch.float32, device=None)
Generates a batch of random 3x3 rotation matrices, uniformly sampled according to the usual rotation metric.
- Parameters
size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.
- Returns
batch of rotation matrices (size x 3x3 tensor).
- random_rotvec(size=(), dtype=torch.float32, device=None)
Generates a batch of random rotation vectors, uniformly sampled according to the usual rotation metric.
- Parameters
size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.
- Returns
batch of rotation vectors (size x 3 tensor).
- random_unitquat(size=(), dtype=torch.float32, device=None)
Generates a batch of random unit quaternions, uniformly sampled according to the usual quaternion metric.
- Parameters
size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.
- Returns
batch of unit quaternions (size x 4 tensor).
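One standard way to draw unit quaternions uniformly is to normalize a 4D vector of independent standard Gaussian samples. The plain-Python sketch below illustrates that idea for single samples. Whether RoMa uses this exact scheme is an assumption not confirmed by this page, and `random_unitquat_sketch` is a hypothetical name.

```python
import math
import random

def random_unitquat_sketch(rng):
    """Draw one unit quaternion (x, y, z, w), uniformly over the unit 3-sphere."""
    while True:
        q = [rng.gauss(0.0, 1.0) for _ in range(4)]
        norm = math.sqrt(sum(c * c for c in q))
        if norm > 1e-12:  # Reject the (practically impossible) near-zero draw.
            return tuple(c / norm for c in q)

rng = random.Random(0)  # Seeded generator, so the sketch is reproducible.
samples = [random_unitquat_sketch(rng) for _ in range(5)]
for q in samples:
    print(q)
```

The Gaussian distribution is rotationally symmetric, so the normalized vector has no preferred direction on the sphere of unit quaternions.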
- rigid_points_registration(x, y)
Returns the rigid transformation \((R,t)\) that best aligns an input list of points \((x_i)_{i=1...n}\) to a target list of points \((y_i)_{i=1...n}\), by minimizing the sum of squared distances \(\sum_i \|R x_i + t - y_i\|^2\). This is sometimes referred to as the Kabsch/Umeyama algorithm.
- Parameters
x (...xNxD tensor) – list of N points of dimension D.
y (...xNxD tensor) – list of corresponding target points.
- Returns
a tuple \((R, t)\) consisting of a rotation matrix \(R\) (…xDxD tensor) and a translation vector \(t\) (…xD tensor).
References
S. Umeyama, “Least-squares estimation of transformation parameters between two point patterns,” IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 13, no. 4, Art. no. 4, 1991.
W. Kabsch, “A solution for the best rotation to relate two sets of vectors,” Acta Crystallographica, A32, 1976.
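To illustrate the least-squares objective, here is a plain-Python sketch restricted to the planar case D=2, where the optimal rotation angle has a closed form. It is not how RoMa computes the result and does not cover general D: the function name is hypothetical and the inputs are plain lists of 2D points.

```python
import math

def rigid_points_registration_2d_sketch(x, y):
    """Least-squares (R, t) aligning 2D points x onto y, with R a 2x2 rotation."""
    n = len(x)
    cx = [sum(p[k] for p in x) / n for k in range(2)]  # Centroid of the source points.
    cy = [sum(p[k] for p in y) / n for k in range(2)]  # Centroid of the target points.
    s_dot, s_cross = 0.0, 0.0
    for p, q in zip(x, y):
        ax, ay = p[0] - cx[0], p[1] - cx[1]
        bx, by = q[0] - cy[0], q[1] - cy[1]
        s_dot += ax * bx + ay * by
        s_cross += ax * by - ay * bx
    # The cost is minimized when cos(theta) * s_dot + sin(theta) * s_cross is maximal.
    theta = math.atan2(s_cross, s_dot)
    c, s = math.cos(theta), math.sin(theta)
    R = [[c, -s], [s, c]]
    # Optimal translation maps the rotated source centroid onto the target centroid.
    t = [cy[0] - (c * cx[0] - s * cx[1]), cy[1] - (s * cx[0] + c * cx[1])]
    return R, t

# Synthetic check: rotate and translate a few points, then recover the motion.
theta_gt, t_gt = 0.5, (1.0, -2.0)
c_gt, s_gt = math.cos(theta_gt), math.sin(theta_gt)
src_pts = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (2.0, 3.0)]
target_pts = [(c_gt * px - s_gt * py + t_gt[0], s_gt * px + c_gt * py + t_gt[1])
              for px, py in src_pts]
R_est, t_est = rigid_points_registration_2d_sketch(src_pts, target_pts)
```

Centering both point sets first decouples the translation from the rotation, which is the same reduction that links points registration to vectors registration.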
- rigid_vectors_registration(x, y)
Returns the rotation matrix \(R\) that best aligns an input list of vectors \((x_i)_{i=1...n}\) to a target list of vectors \((y_i)_{i=1...n}\), by minimizing the sum of squared distances \(\sum_i \|R x_i - y_i\|^2\).
See rigid_points_registration() for details.
- Parameters
x (...xNxD tensor) – list of N vectors of dimension D.
y (...xNxD tensor) – list of corresponding target vectors.
- Returns
The rotation matrix \(R\) (…xDxD tensor).
- rotmat_composition(sequence, normalize=False)
Returns the product of a sequence of rotation matrices.
- Parameters
sequence (sequence of ...xNxN tensors) – sequence of batches of rotation matrices.
normalize – if True, apply special Procrustes orthonormalization to compensate for numerical errors.
- Returns
batch of rotation matrices (…xNxN tensor).
- rotmat_cosine_angle(R)
Returns the cosine angle of the input 3x3 rotation matrix R. Based on the equality \(Trace(R) = 1 + 2 cos(alpha)\).
- Parameters
R (...x3x3 tensor) – batch of 3x3 rotation matrices.
- Returns
batch of cosine angles (… tensor).
- rotmat_geodesic_distance(R1, R2, clamping=1.0)
Returns the angular distance alpha between a pair of rotation matrices. Based on the equality \(\|R_2 - R_1\|_F = 2 \sqrt{2} sin(alpha/2)\).
- Parameters
R1 (...x3x3 tensor) – batch of 3x3 rotation matrices.
R2 (...x3x3 tensor) – batch of 3x3 rotation matrices.
clamping – clamping value applied to the input of torch.asin(). Use 1.0 to ensure valid angular distances. Use a value strictly smaller than 1.0 to ensure finite gradients.
- Returns
batch of angles in radians (… tensor).
- rotmat_geodesic_distance_naive(R1, R2)
Returns the angular distance between a pair of rotation matrices. Based on rotmat_cosine_angle() and less precise than rotmat_geodesic_distance() for nearby rotations.
- Parameters
R1 (...x3x3 tensor) – batch of 3x3 rotation matrices.
R2 (...x3x3 tensor) – batch of 3x3 rotation matrices.
- Returns
batch of angles in radians (… tensor).
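The precision gap between the two formulas can be reproduced with a few lines of plain Python. This sketch uses double-precision floats rather than 32-bit tensors, so the effect shows up at much smaller angles than in single precision. Function names are hypothetical and this is not RoMa's code.

```python
import math

def rotz(theta):
    """Rotation matrix of angle theta around the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def geodesic_distance_frobenius(R1, R2):
    """Angle from ||R2 - R1||_F = 2 sqrt(2) sin(alpha / 2)."""
    squared = sum((R2[i][j] - R1[i][j]) ** 2 for i in range(3) for j in range(3))
    s = min(math.sqrt(squared) / (2.0 * math.sqrt(2.0)), 1.0)  # Clamp before asin.
    return 2.0 * math.asin(s)

def geodesic_distance_trace(R1, R2):
    """Angle from Trace(R1^T R2) = 1 + 2 cos(alpha)."""
    trace = sum(R1[k][i] * R2[k][i] for i in range(3) for k in range(3))
    c = max(-1.0, min(1.0, 0.5 * (trace - 1.0)))  # Clamp before acos.
    return math.acos(c)

theta = 1e-9
identity = rotz(0.0)
alpha_frob = geodesic_distance_frobenius(identity, rotz(theta))
alpha_trace = geodesic_distance_trace(identity, rotz(theta))
print(alpha_frob, alpha_trace)
```

For tiny angles the cosine is indistinguishable from 1 in floating point, so the trace-based estimate loses the angle, while the Frobenius-based estimate works on differences proportional to the angle itself.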
- rotmat_inverse(R)
Returns the inverse of a rotation matrix.
- Parameters
R (...xNxN tensor) – batch of rotation matrices.
- Returns
batch of inverted rotation matrices (…xNxN tensor).
Warning
The function returns a transposed view of the input, therefore one should be careful with in-place operations.
- rotmat_slerp(R0, R1, steps)
Spherical linear interpolation between two rotation matrices.
- Parameters
R0 (Ax3x3 tensor) – batch of rotation matrices (A may contain multiple dimensions).
R1 (Ax3x3 tensor) – batch of rotation matrices (A may contain multiple dimensions).
steps (tensor of shape B) – interpolation steps, 0.0 corresponding to R0 and 1.0 to R1 (B may contain multiple dimensions).
- Returns
batch of interpolated rotation matrices (BxAx3x3 tensor).
- rotvec_composition(sequence, normalize=False)
Returns a rotation vector corresponding to the composition of a sequence of rotations represented by rotation vectors. Composition is performed using an intermediary quaternion representation.
- Parameters
sequence (sequence of ...x3 tensors) – sequence of batches of rotation vectors.
normalize (bool) – if True, normalize intermediary representation to compensate for numerical errors.
- rotvec_inverse(rotvec)
Returns the inverse of the input rotation expressed using rotation vector representation.
- Parameters
rotvec (...x3 tensor) – batch of rotation vectors.
- Returns
batch of rotation vectors (…x3 tensor).
- rotvec_slerp(rotvec0, rotvec1, steps)
Spherical linear interpolation between two rotation vector representations.
- Parameters
rotvec0 (Ax3 tensor) – batch of rotation vectors (A may contain multiple dimensions).
rotvec1 (Ax3 tensor) – batch of rotation vectors (A may contain multiple dimensions).
steps (tensor of shape B) – interpolation steps, 0.0 corresponding to rotvec0 and 1.0 to rotvec1 (B may contain multiple dimensions).
- Returns
batch of interpolated rotation vectors (BxAx3 tensor).
- unitquat_slerp(q0, q1, steps, shortest_path=False)
Spherical linear interpolation between two unit quaternions.
- Parameters
q0 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).
q1 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).
steps (tensor of shape B) – interpolation steps, 0.0 corresponding to q0 and 1.0 to q1 (B may contain multiple dimensions).
shortest_path (boolean) – if True, interpolation will be performed along the shortest path on SO(3).
- Returns
batch of interpolated quaternions (BxAx4 tensor).
Note
When considering quaternions as rotation representations, one should keep in mind that spherical interpolation is not necessarily performed along the shortest arc, depending on the sign of torch.sum(q0*q1,dim=-1).
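The effect of the sign of the dot product can be seen on a single pair of quaternions. The plain-Python sketch below implements textbook SLERP for one interpolation step in XYZW storage. It is not RoMa's implementation: the function name and the near-identical threshold are assumptions, and nearly opposite inputs are not handled.

```python
import math

def unitquat_slerp_sketch(q0, q1, t, shortest_path=False):
    """SLERP between two unit quaternions (x, y, z, w) at step t in [0, 1]."""
    d = sum(a * b for a, b in zip(q0, q1))
    if shortest_path and d < 0.0:
        # q1 and -q1 encode the same rotation: flip to take the shorter arc.
        q1 = tuple(-c for c in q1)
        d = -d
    omega = math.acos(max(-1.0, min(1.0, d)))  # Angle between the 4D unit vectors.
    if omega < 1e-8:
        w0, w1 = 1.0 - t, t  # Nearly identical inputs: linear weights suffice.
    else:
        w0 = math.sin((1.0 - t) * omega) / math.sin(omega)
        w1 = math.sin(t * omega) / math.sin(omega)
    return tuple(w0 * a + w1 * b for a, b in zip(q0, q1))

q_start = (0.0, 0.0, 0.0, 1.0)  # Identity rotation.
# Rotation of 3*pi/2 around z; its dot product with q_start is negative.
q_end = (0.0, 0.0, math.sin(3 * math.pi / 4), math.cos(3 * math.pi / 4))
q_long = unitquat_slerp_sketch(q_start, q_end, 0.5)
q_short = unitquat_slerp_sketch(q_start, q_end, 0.5, shortest_path=True)
print(q_long)   # Halfway along the 3*pi/2 turn.
print(q_short)  # Halfway along the equivalent -pi/2 turn.
```

Both results lie on a valid interpolation path between the same two rotations, but only the second one follows the shortest arc on the rotation space.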
Advanced
Running unit tests
From source repository:
python -m unittest
Building Sphinx documentation
From source repository:
./build_doc.sh
License
RoMa, Copyright (c) 2021 NAVER Corp., is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license (see license).
Bits of code were adapted from SciPy. Documentation is generated, distributed and displayed with the support of Sphinx and other materials (see notice).
References
For a more in-depth discussion regarding differentiable mappings on the rotation space, please refer to:
Romain Brégier, Deep Regression on Manifolds: A 3D Rotation Case Study. In 2021 International Conference on 3D Vision (3DV), 2021 (https://arxiv.org/abs/2103.16317).
Please cite this work in your publications:
@inproceedings{bregier2021deepregression,
title={Deep Regression on Manifolds: a {3D} Rotation Case Study},
author={Br{\'e}gier, Romain},
journal={2021 International Conference on 3D Vision (3DV)},
year={2021}
}
Changelog
- Version 1.2.6:
Added an optional regularization argument to special_procrustes().
Added an optional clamping argument to rotmat_geodesic_distance().
Fix: rotvec_to_rotmat() no longer produces nonfinite gradient for null rotation vectors.
- Version 1.2.5:
Added an optional regularization argument for Procrustes orthonormalization.
Added a rigid registration example in the documentation.
- Version 1.2.4:
Procrustes: automatic fallback to vanilla SVD decomposition for large dimensions.
- Version 1.2.3:
Improved support for double precision tensors.
- Version 1.2.2:
Added rigid_points_registration() and rigid_vectors_registration().
Added rotmat_slerp().
Circumvented a deprecation warning with torch.symeig() when using recent PyTorch versions.
- Version 1.2.1:
Open-source release.