RoMa: A lightweight library to deal with 3D rotations in PyTorch.

RoMa (which stands for Rotation Manipulation) provides differentiable mappings between 3D rotation representations, mappings from Euclidean to rotation space, and various utilities related to rotations. It is implemented in PyTorch and aims to be an easy-to-use and reasonably efficient toolbox for Machine Learning and gradient-based optimization.

For a more in-depth discussion regarding differentiable mappings on the rotation space, please refer to:
Romain Brégier, Deep Regression on Manifolds: A 3D Rotation Case Study, in 2021 International Conference on 3D Vision (3DV), 2021 (https://arxiv.org/abs/2103.16317).

Additionally, please cite this work if you use RoMa for your research:

@inproceedings{bregier2021deepregression,
    title={Deep Regression on Manifolds: a {3D} Rotation Case Study},
    author={Br{\'e}gier, Romain},
    booktitle={2021 International Conference on 3D Vision (3DV)},
    year={2021}
}

Installation

The easiest way to install RoMa is to use pip:

pip install roma

Alternatively one can install the latest version of RoMa directly from the source repository:

pip install git+https://github.com/naver/roma

For PyTorch versions older than 1.8, we recommend installing torch-batch-svd to achieve a significant speed-up with procrustes() on CUDA GPUs (see section Why a new library?). You can check that this module is properly loaded using the function is_torch_batch_svd_available(). With recent PyTorch installations (torch>=1.8), torch-batch-svd is no longer needed or used.

Main features

Supported rotation representations

Rotation vector (rotvec)
  • Encoded using a …x3 tensor.

  • The 3D vector angle * axis represents a rotation by an angle (expressed in radians) around a unit 3D axis.

Unit quaternion (unitquat)
  • Encoded as a …x4 tensor.

Note

  • We use XYZW quaternion convention, i.e. components of quaternion \(x i + y j + z k + w\) are represented by the 4D vector \((x,y,z,w)\).

  • We assume unit quaternions to be of unit length, and do not perform implicit normalization.
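To make the two conventions above concrete, here is a minimal plain-Python sketch (no PyTorch, single rotation only) that builds the rotation vector and the XYZW unit quaternion of the same axis-angle rotation. The helper names are hypothetical and are not part of RoMa.

```python
import math

def axis_angle_to_rotvec(axis, angle):
    """Rotation vector: angle (radians) times the unit axis."""
    return [angle * a for a in axis]

def axis_angle_to_unitquat_xyzw(axis, angle):
    """Unit quaternion sin(angle/2) (x i + y j + z k) + cos(angle/2), stored as (x, y, z, w)."""
    s, c = math.sin(angle / 2.0), math.cos(angle / 2.0)
    return [s * axis[0], s * axis[1], s * axis[2], c]

# Rotation of 90 degrees around the z axis (the axis is assumed to be of unit length).
z_axis = [0.0, 0.0, 1.0]
rotvec = axis_angle_to_rotvec(z_axis, math.pi / 2)
quat = axis_angle_to_unitquat_xyzw(z_axis, math.pi / 2)
```

Note that the scalar part w comes last in the quaternion, and that no normalization is applied: the axis must already be of unit length for the result to be a unit quaternion.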

Rotation matrix (rotmat)
  • Encoded as a …xDxD tensor (D=3 for 3D rotations).

  • We use column-vector convention, i.e. \(R X\) is the transformation of a Dx1 vector \(X\) by a rotation matrix \(R\).
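As an illustration of the column-vector convention, the following plain-Python sketch (hypothetical helpers, not RoMa functions) rotates the x axis by 90 degrees around z and obtains the y axis.

```python
import math

def rot_z(theta):
    """3x3 rotation matrix of angle theta (radians) around the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def matvec(R, X):
    """Matrix-vector product R X, with X treated as a column vector."""
    return [sum(R[i][j] * X[j] for j in range(3)) for i in range(3)]

Y = matvec(rot_z(math.pi / 2), [1.0, 0.0, 0.0])  # approximately (0, 1, 0)
```

When points are stored as rows of an Nx3 tensor, the same rotation is applied with points @ R.T, as done in the rigid registration example of this document.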

Euler and Tait-Bryan angles are NOT currently supported.

This is because of the many different existing conventions, and because of the limited interest of such parameterization for numerical applications.

Mappings between rotation representations

RoMa provides functions to convert between rotation representations.

Example mapping a batch of rotation vectors into corresponding unit quaternions:

import torch, roma
batch_shape = (3, 2)
rotvec = torch.randn(batch_shape + (3,))
q = roma.rotvec_to_unitquat(rotvec)

Mappings from Euclidean to 3D rotation space

Mapping an arbitrary tensor to a valid rotation can be useful e.g. for Machine Learning applications. While rotation vectors or Euler angles can be used for such purpose, they suffer from various shortcomings, and we therefore provide the following alternative mappings:

special_gramschmidt()

Mapping from a 3x2 tensor to a 3x3 rotation matrix, using special Gram-Schmidt orthonormalization (6D representation, popularized by Zhou et al.).

special_procrustes()

Mapping from an arbitrary nxn matrix to an nxn rotation matrix, using special orthogonal Procrustes orthonormalization.

symmatrixvec_to_unitquat()

Mapping from a 10D vector to an antipodal pair of quaternions through eigenvector decomposition of a 4x4 symmetric matrix, proposed by Peretroukhin et al.

For general-purpose applications, we recommend special_procrustes(), which projects an arbitrary square matrix onto the closest rotation matrix with respect to the Frobenius norm. Please refer to this paper for more insights.

Example mapping random 3x3 matrices to valid rotation matrices:

import torch, roma
batch_shape = (5,)
M = torch.randn(batch_shape + (3,3))
R = roma.special_procrustes(M)
assert roma.is_rotation_matrix(R, epsilon=1e-5)

Support for an arbitrary number of batch dimensions

For convenience, functions accept an arbitrary number of batch dimensions:

import torch, roma
print(roma.rotvec_to_rotmat(torch.randn(3)).shape) # -> torch.Size([3, 3])
print(roma.rotvec_to_rotmat(torch.randn(5, 3)).shape) # -> torch.Size([5, 3, 3])
print(roma.rotvec_to_rotmat(torch.randn(2, 5, 3)).shape) # -> torch.Size([2, 5, 3, 3])

Quaternion operations

import torch, roma
q = torch.randn(4) # Random unnormalized quaternion
qconv = roma.quat_conjugation(q) # Quaternion conjugation
qinv = roma.quat_inverse(q) # Quaternion inverse
print(roma.quat_product(q, qinv)) # -> [0,0,0,1] identity quaternion

Rotation composition and inverse

Example using rotation vector representation:

import torch, roma
rotvecs = [roma.random_rotvec() for _ in range(3)] # Random rotation vectors
r = roma.rotvec_composition(rotvecs) # Composition of an arbitrary number of rotations
rinv = roma.rotvec_inverse(r) # Rotation vector corresponding to the inverse rotation

Rotation metrics

RoMa implements some usual similarity measures over the 3D rotation space:

import torch, roma
R1, R2 = roma.random_rotmat(size=5), roma.random_rotmat(size=5)
theta = roma.rotmat_geodesic_distance(R1, R2) # In radians
cos_theta = roma.rotmat_cosine_angle(R1.transpose(-2, -1) @ R2)

Weighted rotation averaging

special_procrustes() can be used to easily average rotations:

import torch, roma
n = 5
R_i = roma.random_rotmat(n) # Batch of n 3x3 rotation matrices to average
w_i = torch.rand(n) # Weight of each matrix, between 0 and 1
M = torch.sum(w_i[:,None, None] * R_i, dim=0) # 3x3 matrix
R = roma.special_procrustes(M) # weighted average.

More precisely, the result is the Fréchet mean of the input rotations with respect to the chordal distance.

Note

The same average could be performed using quaternion representation and symmatrix mapping (slower batched implementation on GPU).

Rigid registration

rigid_points_registration() and rigid_vectors_registration() align ordered sets of points/vectors:

import torch, roma
R_gt = roma.random_rotmat()
t_gt = torch.randn(1, 3)
src = torch.randn(100, 3) # source points / vectors
target = src @ R_gt.T # target vectors
R_predicted = roma.rigid_vectors_registration(src, target)
print(f"R_gt\n{R_gt}")
print(f"R_predicted\n{R_predicted}")

target = src @ R_gt.T + t_gt # target points
R_predicted, t_predicted = roma.rigid_points_registration(src, target)
print(f"R_gt\n{R_gt}")
print(f"R_predicted\n{R_predicted}")
print(f"t_gt\n{t_gt}")
print(f"t_predicted\n{t_predicted}")

Spherical linear interpolation (SLERP)

SLERP between batches of unit quaternions:

import torch, roma
q0, q1 = roma.random_unitquat(size=10), roma.random_unitquat(size=10)
steps = torch.linspace(0, 1.0, 5)
q_interpolated = roma.utils.unitquat_slerp(q0, q1, steps)
idx = 1 # Print interpolations for an arbitrary element of the batch
print('q0:\n', q0[idx])
print('q1:\n', q1[idx])
print('q_interpolated:\n', q_interpolated[:,idx])

SLERP between rotation vectors (shortest path interpolation):

import torch, roma
steps = torch.linspace(0, 1.0, 5)
rotvec0, rotvec1 = torch.randn(3), torch.randn(3)
rotvec_interpolated = roma.rotvec_slerp(rotvec0, rotvec1, steps)

Why a new library?

We could not find a PyTorch library satisfying our needs, so we built our own.

We wanted a reliable, easy-to-use and efficient toolbox to deal with rotation representations in PyTorch. While Kornia provides some utility functions to deal with 3D rotations, it included several major bugs at the time of writing (early 2021) (see e.g. https://github.com/kornia/kornia/issues/723 or https://github.com/kornia/kornia/issues/317).

Care for numerical precision

RoMa is implemented with numerical precision in mind, e.g. with a special handling of small angle rotation vectors or through the choice of appropriate algorithms.

As an example, the plots below show a function that takes as input an angle \(\theta\), produces a rotation matrix \(R_z(\theta)\) of angle \(\theta\), and estimates its geodesic distance with respect to the identity matrix, using 32-bit floating point arithmetic. We observe that rotmat_geodesic_distance() is much more precise for this task than another implementation often found in academic code: rotmat_geodesic_distance_naive(). The backward pass through rotmat_geodesic_distance_naive() leads to unstable gradient estimates and produces Not-a-Number values for small angles, whereas rotmat_geodesic_distance() is well-behaved and returns Not-a-Number only for a 0.0 angle, where the gradient is mathematically undefined.
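The precision gap can be reproduced with a few lines of PyTorch. The sketch below re-implements both estimators from the equalities given in the API documentation (trace-based with acos, Frobenius-based with asin); the function names are hypothetical and the code is an illustration, not RoMa's actual implementation.

```python
import math
import torch

def rot_z(theta):
    """Rotation of angle theta (radians) around the z axis, stored in float32."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float32)

def geodesic_distance_trace(R1, R2):
    """Angle from Trace(R1^T R2) = 1 + 2 cos(alpha), followed by acos."""
    cos_alpha = (torch.trace(R1.T @ R2) - 1.0) / 2.0
    return torch.acos(cos_alpha.clamp(-1.0, 1.0))

def geodesic_distance_frobenius(R1, R2):
    """Angle from ||R2 - R1||_F = 2 sqrt(2) sin(alpha / 2), followed by asin."""
    norm = (R2 - R1).pow(2).sum().sqrt()
    return 2.0 * torch.asin((norm / (2.0 * math.sqrt(2.0))).clamp(max=1.0))

theta = 1e-4
identity = torch.eye(3)
err_trace = abs(geodesic_distance_trace(identity, rot_z(theta)).item() - theta)
err_frobenius = abs(geodesic_distance_frobenius(identity, rot_z(theta)).item() - theta)
```

For such a small angle, cos(theta) is rounded to exactly 1.0 in single precision, so the trace-based estimate loses the angle entirely, while the Frobenius-based estimate only depends on sin(theta), which is represented accurately.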

[Figures: rotmat_geodesic_distance_zero.svg, rotmat_geodesic_distance_grads_zero.svg]

Computation efficiency

RoMa favors code clarity, but aims to be reasonably efficient.

In particular, for Procrustes orthonormalization on NVIDIA GPUs, it can use a batched SVD that provides orders-of-magnitude speed-ups for large batch sizes compared to vanilla torch.svd() for PyTorch versions below 1.8. The plot below was obtained for random 3x3 matrices, with PyTorch 1.7, an NVIDIA Tesla T4 GPU and CUDA 11.0. Note that recent versions of PyTorch (>=1.8) integrate such a speed-up off the shelf.

[Figure: special_procrustes_benchmark.svg]

Syntactic sugar

RoMa aims to be easy-to-use with a simple syntax, and support for an arbitrary number of batch dimensions to let users focus on their applications.

API Documentation

Mappings

Various mappings between different rotation representations.

procrustes(M, force_rotation=False, regularization=0.0, gradient_eps=1e-05, return_singular_values: bool = False)

Returns the orthonormal matrix \(R\) minimizing Frobenius norm \(\| M - R \|_F\).

Parameters:
  • M (...xNxN tensor) – batch of square matrices.

  • force_rotation (bool) – if True, forces the output to be a rotation matrix.

  • regularization (float >= 0) – weight of a regularization term added to the gradient. Using this regularization is equivalent to adding a term \(regularization * \| M - R \|_F^2\) to the training loss function.

  • gradient_eps (float > 0) – small value used to enforce numerical stability during backpropagation.

Returns:

batch of orthonormal matrices (…xNxN tensor) and optional singular values. For advanced users, singular values of the SVD decomposition with sign flipping (… tensor) can optionally be returned by setting the argument return_singular_values to True.

procrustes_naive(M, force_rotation: bool = False, return_singular_values: bool = False)

Implementation of procrustes() relying on default backward pass of autograd and SVD decomposition. Could be slightly less stable than procrustes().

quat_wxyz_to_xyzw(wxyz)

Convert quaternion from WXYZ to XYZW convention.

Parameters:

wxyz (...x4 tensor, WXYZ convention) – batch of quaternions.

Returns:

batch of quaternions (…x4 tensor, XYZW convention).

quat_xyzw_to_wxyz(xyzw)

Convert quaternion from XYZW to WXYZ convention.

Parameters:

xyzw (...x4 tensor, XYZW convention) – batch of quaternions.

Returns:

batch of quaternions (…x4 tensor, WXYZ convention).
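Both conversions are plain reorderings of the four components. A minimal single-quaternion sketch in plain Python (hypothetical helper names, lists instead of tensors) could look like this:

```python
def wxyz_to_xyzw(q):
    """Move the scalar part from the first to the last position."""
    w, x, y, z = q
    return [x, y, z, w]

def xyzw_to_wxyz(q):
    """Move the scalar part from the last to the first position."""
    x, y, z, w = q
    return [w, x, y, z]

identity_wxyz = [1.0, 0.0, 0.0, 0.0]
identity_xyzw = wxyz_to_xyzw(identity_wxyz)
```

This is mostly useful when exchanging data with tools that store the scalar part first.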

rotmat_to_rotvec(R)

Converts rotation matrix to rotation vector representation.

Parameters:

R (...x3x3 tensor) – batch of rotation matrices.

Returns:

batch of rotation vectors (…x3 tensor).

rotmat_to_unitquat(R)

Converts rotation matrix to unit quaternion representation.

Parameters:

R (...x3x3 tensor) – batch of rotation matrices.

Returns:

batch of unit quaternions (…x4 tensor, XYZW convention).

rotvec_to_rotmat(rotvec: Tensor, epsilon=1e-06) -> Tensor

Converts rotation vector to rotation matrix representation. Conversion uses Rodrigues formula in general, and a first order approximation for small angles.

Parameters:
  • rotvec (...x3 tensor) – batch of rotation vectors.

  • epsilon (float) – small angle threshold.

Returns:

batch of rotation matrices (…x3x3 tensor).
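The conversion described above can be sketched for a single rotation vector in plain Python. This is an illustration of the Rodrigues formula \(R = I + \frac{\sin\theta}{\theta} K + \frac{1 - \cos\theta}{\theta^2} K^2\), where \(K\) is the skew-symmetric matrix of the rotation vector, with the first-order fallback \(R \approx I + K\) below the threshold. The helper names are hypothetical, and RoMa's batched implementation may differ in its details.

```python
import math

def skew(v):
    """Skew-symmetric matrix K such that K u is the cross product v x u."""
    x, y, z = v
    return [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

def rotvec_to_rotmat_sketch(rotvec, epsilon=1e-6):
    theta = math.sqrt(sum(c * c for c in rotvec))
    K = skew(rotvec)
    if theta < epsilon:
        a, b = 1.0, 0.0  # first-order approximation for small angles: R ~ I + K
    else:
        a = math.sin(theta) / theta
        b = (1.0 - math.cos(theta)) / theta ** 2
    K2 = matmul(K, K)
    return [[(1.0 if i == j else 0.0) + a * K[i][j] + b * K2[i][j] for j in range(3)]
            for i in range(3)]

R = rotvec_to_rotmat_sketch([0.0, 0.0, math.pi / 2])  # 90 degrees around z
```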

rotvec_to_unitquat(rotvec)

Converts rotation vector into unit quaternion representation.

Parameters:

rotvec (...x3 tensor) – batch of rotation vectors.

Returns:

batch of unit quaternions (…x4 tensor, XYZW convention).

special_gramschmidt(M, epsilon=0)

Returns the 3x3 rotation matrix obtained by Gram-Schmidt orthonormalization of two 3D input vectors (first two columns of input matrix M).

Parameters:
  • M (...x3xN tensor) – batch of 3xN matrices, with N >= 2. Only the first two columns of the matrices are used for orthonormalization.

  • epsilon (float >= 0) – optional clamping value to avoid returning Not-a-Number values in case of ill-defined input.

Returns:

batch of rotation matrices (…x3x3 tensor).

Warning

In case of ill-defined input (colinear input column vectors), the output will not be a rotation matrix.
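For intuition, the orthonormalization can be sketched for a single pair of 3D vectors in plain Python. The helper names are hypothetical, the optional epsilon clamping is omitted, and this is not RoMa's batched implementation.

```python
import math

def normalize(v):
    n = math.sqrt(sum(c * c for c in v))
    return [c / n for c in v]  # fails on a null vector (ill-defined input)

def cross(a, b):
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]

def special_gramschmidt_sketch(x, y):
    """x, y: the first two columns of M. Returns a 3x3 rotation matrix as nested lists."""
    e1 = normalize(x)
    d = sum(a * b for a, b in zip(e1, y))
    e2 = normalize([b - d * a for a, b in zip(e1, y)])  # remove the component along e1
    e3 = cross(e1, e2)  # the cross product guarantees a determinant of +1
    return [[e1[i], e2[i], e3[i]] for i in range(3)]  # e1, e2, e3 as columns

R = special_gramschmidt_sketch([2.0, 0.0, 0.0], [1.0, 3.0, 0.0])
```

With collinear inputs, the vector passed to the second normalization is null, which is the ill-defined case mentioned in the warning above.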

special_procrustes(M, regularization=0.0, gradient_eps=1e-05, return_singular_values: bool = False)

Returns the rotation matrix \(R\) minimizing Frobenius norm \(\| M - R \|_F\).

Parameters:
  • M (...xNxN tensor) – batch of square matrices.

  • regularization (float >= 0) – weight of a regularization term added to the gradient. Using this regularization is equivalent to adding a term \(regularization * \| M - R \|_F^2\) to the training loss function.

  • gradient_eps (float > 0) – small value used to enforce numerical stability during backpropagation.

Returns:

batch of rotation matrices (…xNxN tensor). For advanced users, singular values of the SVD decomposition with sign flipping (… tensor) can optionally be returned by setting the argument return_singular_values to True.
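The forward pass of such a projection can be sketched with an SVD and a sign correction. The function below is an illustration under stated assumptions: it uses torch.linalg.svd (available in PyTorch 1.8 and later), the name special_procrustes_sketch is hypothetical, and it does not reproduce the regularization or the gradient_eps handling of the backward pass.

```python
import torch

def special_procrustes_sketch(M):
    """Project a batch of ...xNxN matrices onto rotation matrices (forward pass only)."""
    U, S, Vh = torch.linalg.svd(M)
    # det(U @ Vh) is +1 or -1; flipping the last singular direction enforces det(R) = +1.
    sign = torch.sign(torch.det(U @ Vh))
    D = torch.ones_like(S)
    D[..., -1] = sign
    return U @ torch.diag_embed(D) @ Vh

torch.manual_seed(0)
M = torch.randn(4, 3, 3)
R = special_procrustes_sketch(M)
```

Without the sign correction, the result would be the closest orthonormal matrix, which may be a reflection rather than a rotation.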

special_procrustes_naive(M, return_singular_values: bool = False)

Implementation of special_procrustes() relying on default backward pass of autograd and SVD decomposition. Could be slightly less stable than special_procrustes().

symmatrix_to_projective_point(A)

Converts a DxD symmetric matrix A into a projective point represented by a unit vector \(q\) minimizing \(q^T A q\). Only the lower part of the matrix is considered in practice.

Parameters:

A (...xDxD tensor) – batch of symmetric matrices. Only the lower triangular part is considered.

Returns:

batch of unit vectors \(q\) (…xD tensor).

Reference:
  V. Peretroukhin, M. Giamou, D. M. Rosen, W. N. Greene, N. Roy, and J. Kelly, “A Smooth Representation of Belief over SO(3) for Deep Rotation Learning with Uncertainty,” 2020, doi: 10.15607/RSS.2020.XVI.007.

Warning

  • This mapping is unstable when the smallest eigenvalue of A has a multiplicity strictly greater than 1.

  • The eigenvalue decomposition may fail, in particular when using single precision numbers.

  • Current implementation is rather slow due to the implementation of torch.symeig. The CuSolver library provides a faster eigenvalue decomposition alternative, but results were found to be unreliable.

symmatrixvec_to_unitquat(x)

Converts a 10D vector into a unit quaternion representation. Based on symmatrix_to_projective_point().

Parameters:

x (...x10 tensor) – batch of 10D vectors.

Returns:

batch of unit quaternions (…x4 tensor, XYZW convention).

Reference:
  V. Peretroukhin, M. Giamou, D. M. Rosen, W. N. Greene, N. Roy, and J. Kelly, “A Smooth Representation of Belief over SO(3) for Deep Rotation Learning with Uncertainty,” 2020, doi: 10.15607/RSS.2020.XVI.007.

unitquat_to_rotmat(quat)

Converts unit quaternion into rotation matrix representation.

Parameters:

quat (...x4 tensor, XYZW convention) – batch of unit quaternions. No normalization is applied before computation.

Returns:

batch of rotation matrices (…x3x3 tensor).

unitquat_to_rotvec(quat, shortest_arc=True)

Converts unit quaternion into rotation vector representation.

Based on the representation of a rotation of angle \({\theta}\) and unit axis \((x,y,z)\) by the unit quaternions \(\pm [\sin({\theta} / 2) (x i + y j + z k) + \cos({\theta} / 2)]\).

Parameters:
  • quat (...x4 tensor, XYZW convention) – batch of unit quaternions. No normalization is applied before computation.

  • shortest_arc (bool) – if True, the function returns the smallest rotation vectors corresponding to the input 3D rotations, i.e. rotation vectors with a norm smaller than \(\pi\). If False, the function may return rotation vectors of norm larger than \(\pi\), depending on the sign of the input quaternions.

Returns:

batch of rotation vectors (…x3 tensor).

Note

Behavior is undefined for inputs quat=torch.as_tensor([0.0, 0.0, 0.0, -1.0]) and shortest_arc=False, as any rotation vector of angle \(2 \pi\) could be a valid representation in such case.
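The effect of shortest_arc can be illustrated with a single-quaternion sketch in plain Python. The function name is hypothetical, and the value returned for a null vector part is an arbitrary choice made for this illustration.

```python
import math

def unitquat_to_rotvec_sketch(q, shortest_arc=True):
    x, y, z, w = q
    if shortest_arc and w < 0:
        x, y, z, w = -x, -y, -z, -w  # q and -q represent the same rotation
    n = math.sqrt(x * x + y * y + z * z)  # equals |sin(theta / 2)| for a unit quaternion
    if n < 1e-12:
        return [0.0, 0.0, 0.0]  # arbitrary choice for this sketch
    theta = 2.0 * math.atan2(n, w)
    return [theta * x / n, theta * y / n, theta * z / n]

s, c = math.sin(math.pi / 4), math.cos(math.pi / 4)
q = [0.0, 0.0, s, c]       # 90 degrees around z
neg_q = [0.0, 0.0, -s, -c]  # same rotation, opposite sign
short = unitquat_to_rotvec_sketch(neg_q, shortest_arc=True)
long = unitquat_to_rotvec_sketch(neg_q, shortest_arc=False)
```

Here short is a rotation vector of norm \(\pi/2\) around z, whereas long has norm \(3\pi/2\) and points along the opposite axis; both describe the same 3D rotation.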

Utils

Various utility functions related to rotation representations.

identity_quat(size=(), dtype=torch.float32, device=None)

Return a batch of identity unit quaternions.

Parameters:

size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.

Returns:

batch of identity quaternions (size x 4 tensor, XYZW convention).

Note

All returned batch quaternions refer to the same memory location. Consider cloning the output tensor prior to performing any in-place operations.

is_orthonormal_matrix(R, epsilon=1e-07)

Test if matrices are orthonormal.

Parameters:
  • R (...xDxD tensor) – batch of square matrices.

  • epsilon – tolerance threshold.

Returns:

boolean tensor (shape …).

is_rotation_matrix(R, epsilon=1e-07)

Test if matrices are rotation matrices.

Parameters:
  • R (...xDxD tensor) – batch of square matrices.

  • epsilon – tolerance threshold.

Returns:

boolean tensor (shape …).

is_torch_batch_svd_available() -> bool

Returns True if the module ‘torch_batch_svd’ has been loaded. Returns False otherwise.

quat_action(q, v, is_normalized=False)

Rotate a 3D vector \(v=(x,y,z)\) by a rotation represented by a quaternion q.

Based on the action by conjugation \((q, v) \mapsto q v q^{-1}\), considering the pure quaternion \(v = xi + yj + zk\) by abuse of notation.

Parameters:
  • q (...x4 tensor, XYZW convention) – batch of quaternions.

  • v (...x3 tensor) – batch of 3D vectors.

  • is_normalized – use True if the input quaternions are already normalized, to avoid unnecessary computations.

Returns:

batch of rotated 3D vectors (…x3 tensor).

Note

One should favor rotation matrix representation to rotate multiple vectors by the same rotation efficiently.
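For a unit quaternion with vector part \(u\) and scalar part \(w\), the conjugation \(q v q^{-1}\) expands to \(v + 2 w (u \times v) + 2 u \times (u \times v)\). The plain-Python sketch below uses this equivalent expansion rather than explicit quaternion products; the helper names are hypothetical and this is not necessarily how RoMa computes the result.

```python
import math

def cross(a, b):
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]

def quat_action_sketch(q, v, is_normalized=False):
    x, y, z, w = q
    if not is_normalized:
        n = math.sqrt(x * x + y * y + z * z + w * w)
        x, y, z, w = x / n, y / n, z / n, w / n
    u = [x, y, z]
    t = [2.0 * c for c in cross(u, v)]  # t = 2 (u x v)
    ut = cross(u, t)                    # u x t = 2 u x (u x v)
    return [v[i] + w * t[i] + ut[i] for i in range(3)]

# Unnormalized quaternion representing a rotation of 90 degrees around z.
q = [0.0, 0.0, 2.0 * math.sin(math.pi / 4), 2.0 * math.cos(math.pi / 4)]
rotated = quat_action_sketch(q, [1.0, 0.0, 0.0])  # approximately (0, 1, 0)
```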

quat_composition(sequence, normalize=False)

Returns the product of a sequence of quaternions.

Parameters:
  • sequence (sequence of ...x4 tensors, XYZW convention) – sequence of batches of quaternions.

  • normalize (bool) – if True, normalize the returned quaternion.

Returns:

batch of quaternions (…x4 tensor, XYZW convention).

quat_conjugation(quat)

Returns the conjugation of input batch of quaternions.

Parameters:

quat (...x4 tensor, XYZW convention) – batch of quaternions.

Returns:

batch of quaternions (…x4 tensor, XYZW convention).

Note

Conjugation of a unit quaternion is equal to its inverse.

quat_inverse(quat)

Returns the inverse of a batch of quaternions.

Parameters:

quat (...x4 tensor, XYZW convention) – batch of quaternions.

Returns:

batch of quaternions (…x4 tensor, XYZW convention).

Note

  • Inverse of null quaternion is undefined.

  • For unit quaternions, consider using conjugation instead.

quat_normalize(quat)

Returns a normalized, unit norm, copy of a batch of quaternions.

Parameters:

quat (...x4 tensor, XYZW convention) – batch of quaternions.

Returns:

batch of quaternions (…x4 tensor, XYZW convention).

quat_product(p, q)

Returns the product of two quaternions.

Parameters:
  • p (...x4 tensor, XYZW convention) – batch of quaternions.

  • q (...x4 tensor, XYZW convention) – batch of quaternions.

Returns:

batch of quaternions (…x4 tensor, XYZW convention).
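As a reference for the XYZW convention, the Hamilton product and the inverse can be written out component by component. The sketch below is plain Python for single quaternions, with hypothetical function names; RoMa's functions operate on batched tensors.

```python
def quat_product_sketch(p, q):
    """Hamilton product p * q, both stored as (x, y, z, w)."""
    px, py, pz, pw = p
    qx, qy, qz, qw = q
    return [
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
        pw * qw - px * qx - py * qy - pz * qz,
    ]

def quat_inverse_sketch(q):
    """Inverse = conjugate divided by the squared norm (undefined for the null quaternion)."""
    x, y, z, w = q
    n2 = x * x + y * y + z * z + w * w
    return [-x / n2, -y / n2, -z / n2, w / n2]

q = [0.5, -1.0, 2.0, 0.25]
identity = quat_product_sketch(q, quat_inverse_sketch(q))              # approximately (0, 0, 0, 1)
k = quat_product_sketch([1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0])  # i * j = k
```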

random_rotmat(size=(), dtype=torch.float32, device=None)

Generates a batch of random 3x3 rotation matrices, uniformly sampled according to the usual rotation metric.

Parameters:

size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.

Returns:

batch of rotation matrices (size x 3x3 tensor).

random_rotvec(size=(), dtype=torch.float32, device=None)

Generates a batch of random rotation vectors, uniformly sampled according to the usual rotation metric.

Parameters:

size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.

Returns:

batch of rotation vectors (size x 3 tensor).

random_unitquat(size=(), dtype=torch.float32, device=None)

Generates a batch of random unit quaternions, uniformly sampled according to the usual quaternion metric.

Parameters:

size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.

Returns:

batch of unit quaternions (size x 4 tensor).

Reference:
  K. Shoemake, “Uniform Random Rotations”, in Graphics Gems III (IBM Version), Elsevier, 1992, pp. 124–132. doi: 10.1016/B978-0-08-050755-2.50036-1.
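The sampling scheme of the cited reference draws three uniform numbers and combines them as shown in the plain-Python sketch below. The function name is hypothetical, and whether RoMa orders the components exactly this way is an assumption.

```python
import math
import random

def random_unitquat_sketch(rng=random):
    """Uniform random unit quaternion from three uniform samples in [0, 1)."""
    u1, u2, u3 = rng.random(), rng.random(), rng.random()
    a, b = math.sqrt(1.0 - u1), math.sqrt(u1)
    return [a * math.sin(2.0 * math.pi * u2), a * math.cos(2.0 * math.pi * u2),
            b * math.sin(2.0 * math.pi * u3), b * math.cos(2.0 * math.pi * u3)]

q = random_unitquat_sketch(random.Random(0))
```

The result has unit norm by construction, since \(a^2 + b^2 = 1\).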

rigid_points_registration(x, y, weights=None, compute_scaling=False)

Returns the rigid transformation \((R,t)\) and the optional scaling \(s\) that best align an input list of points \((x_i)_{i=1...n}\) to a target list of points \((y_i)_{i=1...n}\), by minimizing the sum of squared distances \(\sum_i w_i \|s R x_i + t - y_i\|^2\), where \((w_i)_{i=1...n}\) denotes optional positive weights. This is sometimes referred to as the Kabsch/Umeyama algorithm.

Parameters:
  • x (...xNxD tensor) – list of N points of dimension D.

  • y (...xNxD tensor) – list of corresponding target points.

  • weights (None or ...xN tensor) – optional list of weights associated to each point.

Returns:

a triplet \((R, t, s)\) consisting of a rotation matrix \(R\) (…xDxD tensor), a translation vector \(t\) (…xD tensor), and a scaling \(s\) (… tensor) if compute_scaling=True. Returns \((R, t)\) otherwise.

References

  S. Umeyama, “Least-squares estimation of transformation parameters between two point patterns,” IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 13, no. 4, 1991.

  W. Kabsch, “A solution for the best rotation to relate two sets of vectors,” Acta Crystallographica, A32, 1976.
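A minimal sketch of the unweighted, unscaled, unbatched case may help to see how the registration works: center both point sets, build the cross-covariance matrix, project it onto the rotation space with an SVD, then recover the translation. The function name is hypothetical, torch.linalg.svd is assumed to be available (PyTorch 1.8 or later), and the weights and scaling supported by RoMa are left out.

```python
import math
import torch

def rigid_points_registration_sketch(x, y):
    """x, y: NxD tensors. Returns (R, t) minimizing sum_i ||R x_i + t - y_i||^2."""
    x_mean, y_mean = x.mean(dim=0), y.mean(dim=0)
    H = (y - y_mean).T @ (x - x_mean)  # DxD cross-covariance, sum_i y_i x_i^T after centering
    U, S, Vh = torch.linalg.svd(H)
    D = torch.ones_like(S)
    D[-1] = torch.sign(torch.det(U @ Vh))  # avoid returning a reflection
    R = U @ torch.diag(D) @ Vh
    t = y_mean - R @ x_mean
    return R, t

# Synthetic check: rotation of 0.3 rad around z, plus a translation.
c, s = math.cos(0.3), math.sin(0.3)
R_gt = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64)
t_gt = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float64)
torch.manual_seed(0)
src = torch.randn(50, 3, dtype=torch.float64)
target = src @ R_gt.T + t_gt
R_est, t_est = rigid_points_registration_sketch(src, target)
```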

rigid_vectors_registration(x, y, weights=None, compute_scaling=False)

Returns the rotation matrix \(R\) and the optional scaling \(s\) that best align an input list of vectors \((x_i)_{i=1...n}\) to a target list of vectors \((y_i)_{i=1...n}\) by minimizing the sum of squared distances \(\sum_i w_i \|s R x_i - y_i\|^2\), where \((w_i)_{i=1...n}\) denotes optional positive weights. See rigid_points_registration() for details.

Parameters:
  • x (...xNxD tensor) – list of N vectors of dimension D.

  • y (...xNxD tensor) – list of corresponding target vectors.

  • weights (None or ...xN tensor) – optional list of weights associated to each vector.

Returns:

A tuple \((R, s)\) consisting of the rotation matrix \(R\) (…xDxD tensor) and the scaling \(s\) (… tensor) if compute_scaling=True. Returns the rotation matrix \(R\) otherwise.

rotmat_composition(sequence, normalize=False)

Returns the product of a sequence of rotation matrices.

Parameters:
  • sequence (sequence of ...xNxN tensors) – sequence of batches of rotation matrices.

  • normalize – if True, apply special Procrustes orthonormalization to compensate for numerical errors.

Returns:

batch of rotation matrices (…xNxN tensor).

rotmat_cosine_angle(R)

Returns the cosine of the rotation angle of the input 3x3 rotation matrix R. Based on the equality \(\mathrm{Trace}(R) = 1 + 2 \cos(\alpha)\).

Parameters:

R (...x3x3 tensor) – batch of 3x3 rotation matrices.

Returns:

batch of cosine angles (… tensor).

rotmat_geodesic_distance(R1, R2, clamping=1.0)

Returns the angular distance \(\alpha\) between a pair of rotation matrices. Based on the equality \(\|R_2 - R_1\|_F = 2 \sqrt{2} \sin(\alpha/2)\).

Parameters:
  • R1 (...x3x3 tensor) – batch of 3x3 rotation matrices.

  • R2 (...x3x3 tensor) – batch of 3x3 rotation matrices.

  • clamping – clamping value applied to the input of torch.asin(). Use 1.0 to ensure valid angular distances. Use a value strictly smaller than 1.0 to ensure finite gradients.

Returns:

batch of angles in radians (… tensor).

rotmat_geodesic_distance_naive(R1, R2)

Returns the angular distance between a pair of rotation matrices. Based on rotmat_cosine_angle() and less precise than rotmat_geodesic_distance() for nearby rotations.

Parameters:
  • R1 (...x3x3 tensor) – batch of 3x3 rotation matrices.

  • R2 (...x3x3 tensor) – batch of 3x3 rotation matrices.

Returns:

batch of angles in radians (… tensor).

rotmat_inverse(R)

Returns the inverse of a rotation matrix.

Parameters:

R (...xNxN tensor) – batch of rotation matrices.

Returns:

batch of inverted rotation matrices (…xNxN tensor).

Warning

The function returns a transposed view of the input, therefore one should be careful with in-place operations.

rotmat_slerp(R0, R1, steps)

Spherical linear interpolation between two rotation matrices.

Parameters:
  • R0 (Ax3x3 tensor) – batch of rotation matrices (A may contain multiple dimensions).

  • R1 (Ax3x3 tensor) – batch of rotation matrices (A may contain multiple dimensions).

  • steps (tensor of shape B) – interpolation steps, 0.0 corresponding to R0 and 1.0 to R1 (B may contain multiple dimensions).

Returns:

batch of interpolated rotation matrices (BxAx3x3 tensor).

rotvec_composition(sequence, normalize=False)

Returns a rotation vector corresponding to the composition of a sequence of rotations represented by rotation vectors. Composition is performed using an intermediary quaternion representation.

Parameters:
  • sequence (sequence of ...x3 tensors) – sequence of batches of rotation vectors.

  • normalize (bool) – if True, normalize intermediary representation to compensate for numerical errors.

rotvec_geodesic_distance(vec1, vec2)

Returns the angular distance between rotations represented by rotation vectors (uses a conversion to unit quaternions internally).

Parameters:
  • vec1 (...x3 tensor) – batch of rotation vectors.

  • vec2 (...x3 tensor) – batch of rotation vectors.

Returns:

batch of angles in radians (… tensor).

rotvec_inverse(rotvec)

Returns the inverse of the input rotation expressed using rotation vector representation.

Parameters:

rotvec (...x3 tensor) – batch of rotation vectors.

Returns:

batch of rotation vectors (…x3 tensor).

rotvec_slerp(rotvec0, rotvec1, steps)

Spherical linear interpolation between two rotation vector representations.

Parameters:
  • rotvec0 (Ax3 tensor) – batch of rotation vectors (A may contain multiple dimensions).

  • rotvec1 (Ax3 tensor) – batch of rotation vectors (A may contain multiple dimensions).

  • steps (tensor of shape B) – interpolation steps, 0.0 corresponding to rotvec0 and 1.0 to rotvec1 (B may contain multiple dimensions).

Returns:

batch of interpolated rotation vectors (BxAx3 tensor).

unitquat_geodesic_distance(q1, q2)

Returns the angular distance \(\alpha\) between rotations represented by unit quaternions. Based on the equality \(\min \|q_2 \pm q_1\| = 2 |\sin(\alpha/4)|\).

Parameters:
  • q1 (...x4 tensor, XYZW convention) – batch of unit quaternions.

  • q2 (...x4 tensor, XYZW convention) – batch of unit quaternions.

Returns:

batch of angles in radians (… tensor).
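The equality above translates directly into a few lines of plain Python for a single pair of unit quaternions. The function name is hypothetical and this is only an illustration of the formula.

```python
import math

def unitquat_geodesic_distance_sketch(q1, q2):
    """Angle alpha such that min ||q2 +/- q1|| = 2 |sin(alpha / 4)|."""
    d_minus = math.sqrt(sum((b - a) ** 2 for a, b in zip(q1, q2)))
    d_plus = math.sqrt(sum((b + a) ** 2 for a, b in zip(q1, q2)))
    d = min(d_minus, d_plus)  # q2 and -q2 represent the same rotation
    return 4.0 * math.asin(min(d / 2.0, 1.0))

identity = [0.0, 0.0, 0.0, 1.0]
s, c = math.sin(math.pi / 4), math.cos(math.pi / 4)
alpha = unitquat_geodesic_distance_sketch(identity, [0.0, 0.0, s, c])        # 90 degrees around z
alpha_neg = unitquat_geodesic_distance_sketch(identity, [0.0, 0.0, -s, -c])  # same rotation
```

Taking the minimum over both signs makes the result independent of the sign of the input quaternions.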

unitquat_slerp(q0, q1, steps, shortest_arc=True)

Spherical linear interpolation between two unit quaternions.

Parameters:
  • q0 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).

  • q1 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).

  • steps (tensor of shape B) – interpolation steps, 0.0 corresponding to q0 and 1.0 to q1 (B may contain multiple dimensions).

  • shortest_arc (boolean) – if True, interpolation will be performed along the shortest arc on SO(3) from q0 to q1 or -q1.

Returns:

batch of interpolated quaternions (BxAx4 tensor).

Note

When considering quaternions as rotation representations, one should keep in mind that spherical interpolation is not necessarily performed along the shortest arc, depending on the sign of torch.sum(q0*q1,dim=-1).

Behavior is undefined when using shortest_arc=False with antipodal quaternions.
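The role of shortest_arc can be seen in a single-pair, single-step sketch in plain Python. The function name and the fallback used for nearly parallel quaternions are assumptions of this illustration; RoMa's function works on batches and on a tensor of steps.

```python
import math

def unitquat_slerp_sketch(q0, q1, t, shortest_arc=True):
    dot = sum(a * b for a, b in zip(q0, q1))
    if shortest_arc and dot < 0:
        q1 = [-c for c in q1]  # interpolate towards -q1, which represents the same rotation
        dot = -dot
    omega = math.acos(max(-1.0, min(1.0, dot)))  # angle between q0 and q1 on the unit sphere
    so = math.sin(omega)
    if so < 1e-8:
        return list(q0)  # nearly parallel or antipodal inputs: arbitrary fallback
    a = math.sin((1.0 - t) * omega) / so
    b = math.sin(t * omega) / so
    return [a * x + b * y for x, y in zip(q0, q1)]

q0 = [0.0, 0.0, 0.0, 1.0]                                      # identity
q1 = [0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)]  # 90 degrees around z
q_mid = unitquat_slerp_sketch(q0, q1, 0.5)                     # 45 degrees around z
```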

unitquat_slerp_fast(q0, q1, steps, shortest_arc=True)

Spherical linear interpolation between two unit quaternions. This function requires fewer computations than roma.utils.unitquat_slerp(), but is unsuitable for extrapolation (i.e. steps must be within [0,1]).

Parameters:
  • q0 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).

  • q1 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).

  • steps (tensor of shape B) – interpolation steps within 0.0 and 1.0, 0.0 corresponding to q0 and 1.0 to q1 (B may contain multiple dimensions).

  • shortest_arc (boolean) – if True, interpolation will be performed along the shortest arc on SO(3) from q0 to q1 or -q1.

Returns:

batch of interpolated quaternions (BxAx4 tensor).

Spatial transformations

Spatial transformations parameterized by rotation matrices, unit quaternions and more.

Example of use

import torch, roma
# Rigid transformation parameterized by a rotation matrix and a translation vector
T0 = roma.Rigid(linear=roma.random_rotmat(), translation=torch.randn(3))

# Rigid transformations parameterized by a unit quaternion and a translation vector
T1 = roma.RigidUnitQuat(linear=roma.random_unitquat(), translation=torch.randn(3))
T2 = roma.RigidUnitQuat(linear=roma.random_unitquat(), translation=torch.randn(3))

# Inverting and composing transformations
T = (T1.inverse() @ T2)

# Normalization to ensure that T is actually a rigid transformation.
T = T.normalize()

# Direct access to the translation part
T.translation += 0.5

# Transformation of points:
points = torch.randn(100,3)
# Adjusting the shape of T for proper broadcasting.
transformed_points = T[None].apply(points)

# Transformation of vectors:
vectors = torch.randn(10,20,3)
# Adjusting the shape of T for proper broadcasting.
transformed_vectors = T[None,None].linear_apply(vectors)

# Casting the transformation into an homogeneous 4x4 matrix.
M = T.to_homogeneous()

Applying a transformation

When applying a transformation to a set of points of coordinates v, the batch shape of v should be broadcastable with the batch shape of the transformation.

For example, one can sample a unique random rigid 3D transformation and use it to transform 100 random 3D points as follows:

roma.Rigid(roma.random_rotmat(), torch.randn(3))[None].apply(torch.randn(100,3))

To apply a different transformation to each point, one could use instead:

roma.Rigid(roma.random_rotmat(100), torch.randn(100,3)).apply(torch.randn(100,3))

Aliasing issues

Warning

For efficiency reasons, transformation objects do not copy input data. Be careful if you intend to do some in-place data modifications, and use the clone() method when required.

class Affine(linear, translation)

An affine transformation represented by a linear and a translation part.

Variables:
  • linear – (…xCxD tensor): batch of matrices specifying the linear part.

  • translation – (…xD tensor): batch of vectors specifying the translation part.

apply(v)

Transforms a tensor of point coordinates. See Applying a transformation.

Parameters:

v (...xD tensor) – tensor of point coordinates to transform.

Returns:

The transformed point coordinates.

as_tuple()
Returns:

a tuple of tensors containing the linear and translation parts of the transformation respectively.

clone()
Returns:

A copy of the transformation (useful to avoid aliasing issues).

compose(other)

Compose a transformation with the current one.

Parameters:

other – another transformation of the same type.

Returns:

The resulting transformation.

classmethod from_homogeneous(matrix)

Instantiate a new transformation from an input homogeneous (D+1)x(C+1) matrix. The input matrix is assumed to be normalized and to satisfy the properties of the transformation. No checks are performed.

Parameters:

matrix (...x(D+1)x(C+1) tensor) – tensor of transformations expressed in homogeneous coordinates, normalized with a last row equal to (0,…,0,1).

Returns:

The corresponding transformation.

inverse()
Returns:

The inverse transformation, when applicable.
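For an affine map x -> L x + t, composition and inversion have simple closed forms. The sketch below writes them out in plain PyTorch. It assumes that composing applies the other transformation first and the current one second, which the description above does not state explicitly, and the helper names are hypothetical.

```python
import torch

def compose_sketch(L1, t1, L2, t2):
    """(L1, t1) o (L2, t2): x -> L1 (L2 x + t2) + t1 (hypothetical helper)."""
    return L1 @ L2, L1 @ t2 + t1

def inverse_sketch(L, t):
    """Inverse of x -> L x + t for a square invertible L (hypothetical helper)."""
    L_inv = torch.linalg.inv(L)
    return L_inv, -(L_inv @ t)

L = torch.tensor([[2., 1., 0.],
                  [0., 1., 0.],
                  [0., 0., 4.]], dtype=torch.float64)
t = torch.tensor([1., -2., 3.], dtype=torch.float64)

# Composing a transformation with its inverse should give the identity map.
L_id, t_id = compose_sketch(L, t, *inverse_sketch(L, t))
print(L_id, t_id)
```

The inverse only exists when the linear part is square and invertible, which is presumably what "when applicable" refers to.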

linear_apply(v)

Transforms a tensor of vector coordinates.

Parameters:

v (...xD tensor) – tensor of vector coordinates to transform.

Returns:

The transformed vector coordinates.

See note in apply() regarding broadcasting.
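The difference between points and vectors can be checked numerically: a displacement between two points is only affected by the linear part. The following sketch uses plain PyTorch with hypothetical helpers, assuming apply() computes R p + t and linear_apply() computes R d.

```python
import torch

R = torch.tensor([[0., -1., 0.],
                  [1.,  0., 0.],
                  [0.,  0., 1.]])
t = torch.tensor([1., 2., 3.])

def apply_sketch(p):
    """Transform a point: linear part plus translation (hypothetical helper)."""
    return R @ p + t

def linear_apply_sketch(d):
    """Transform a vector (direction or displacement): linear part only."""
    return R @ d

p = torch.tensor([1., 0., 0.])
q = torch.tensor([0., 2., 1.])

# The displacement between transformed points equals the transformed displacement.
moved_offset = apply_sketch(q) - apply_sketch(p)
rotated_offset = linear_apply_sketch(q - p)
print(moved_offset, rotated_offset)
```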

linear_compose(other)

Compose the linear part of two transformations.

Parameters:

other – another transformation of the same type.

Returns:

a tensor representing the composed transformation.

linear_inverse()
Returns:

The inverse of the linear transformation, when applicable.

normalize()
Returns:

Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).

to_homogeneous(output=None)
Parameters:

output (...x(D+1)x(C+1) tensor or None) – optional tensor in which to store the result.

Returns:

A …x(D+1)x(C+1) tensor of homogeneous matrices representing the transformation, normalized with a last row equal to (0,…,0,1).
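The homogeneous layout described above can be sketched in plain PyTorch as follows. The helper names are hypothetical and this is not the library's implementation; it only illustrates how a …xDxC linear part and a …xD translation fit into a (D+1)x(C+1) matrix whose last row is (0,…,0,1).

```python
import torch

def to_homogeneous_sketch(linear, translation):
    """Pack a ...xDxC linear part and a ...xD translation into ...x(D+1)x(C+1) matrices."""
    *batch, D, C = linear.shape
    H = torch.zeros(*batch, D + 1, C + 1, dtype=linear.dtype)
    H[..., :D, :C] = linear
    H[..., :D, C] = translation
    H[..., D, C] = 1.0
    return H

def from_homogeneous_sketch(H):
    """Split homogeneous matrices into linear and translation parts (returned as views)."""
    return H[..., :-1, :-1], H[..., :-1, -1]

linear = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # D=2, C=3
translation = torch.tensor([10., 20.])
H = to_homogeneous_sketch(linear, translation)

x = torch.tensor([1., 2., 3.])
x_h = torch.cat([x, torch.ones(1)])  # homogeneous coordinates of x
y_h = H @ x_h                        # equals (linear @ x + translation, 1)
print(y_h)  # tensor([18., 46.,  1.])
```

Since the split returns views of the input, modifying `H` in place afterwards would also modify the extracted parts.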

class Isometry(linear, translation)

An isometric transformation represented by an orthonormal and a translation part.

Variables:
  • linear – (…xDxD tensor): batch of matrices specifying the linear part.

  • translation – (…xD tensor): batch of vectors specifying the translation part.

apply(v)

Transforms a tensor of point coordinates. See Applying a transformation.

Parameters:

v (...xD tensor) – tensor of point coordinates to transform.

Returns:

The transformed point coordinates.

as_tuple()
Returns:

a tuple of tensors containing the linear and translation parts of the transformation respectively.

clone()
Returns:

A copy of the transformation (useful to avoid aliasing issues).

compose(other)

Compose a transformation with the current one.

Parameters:

other – another transformation of the same type.

Returns:

The resulting transformation.

classmethod from_homogeneous(matrix)

Instantiate a new transformation from an input homogeneous (D+1)x(C+1) matrix. The input matrix is assumed to be normalized and to satisfy the properties of the transformation. No checks are performed.

Parameters:

matrix (...x(D+1)x(C+1) tensor) – tensor of transformations expressed in homogeneous coordinates, normalized with a last row equal to (0,…,0,1).

Returns:

The corresponding transformation.

inverse()
Returns:

The inverse transformation, when applicable.

linear_apply(v)

Transforms a tensor of vector coordinates.

Parameters:

v (...xD tensor) – tensor of vector coordinates to transform.

Returns:

The transformed vector coordinates.

See note in apply() regarding broadcasting.

linear_compose(other)

Compose the linear part of two transformations.

Parameters:

other – another transformation of the same type.

Returns:

a tensor representing the composed transformation.

linear_inverse()
Returns:

The inverse of the linear transformation, when applicable.

linear_normalize()
Returns:

Linear transformation normalized to an orthonormal matrix (…xDxD tensor).
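One common way to project a matrix onto the closest orthonormal matrix (in the Frobenius sense) is the orthogonal Procrustes solution based on an SVD. The sketch below shows that technique in plain PyTorch. Whether linear_normalize() proceeds exactly this way is an assumption, and the helper name is hypothetical.

```python
import torch

def orthonormalize_sketch(M):
    """Closest orthonormal matrix to M in the Frobenius sense (hypothetical helper)."""
    U, _, Vh = torch.linalg.svd(M)
    return U @ Vh

# A slightly perturbed identity matrix.
M = torch.tensor([[1.0, 0.1, 0.0],
                  [0.0, 0.9, 0.2],
                  [0.1, 0.0, 1.1]], dtype=torch.float64)
Q = orthonormalize_sketch(M)

# Q^T Q should be the identity up to numerical precision.
print(Q.T @ Q)
```

Note that the result may have determinant -1: an orthonormal matrix is not necessarily a rotation.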

normalize()
Returns:

Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).

to_homogeneous(output=None)
Parameters:

output (...x(D+1)x(C+1) tensor or None) – optional tensor in which to store the result.

Returns:

A …x(D+1)x(C+1) tensor of homogeneous matrices representing the transformation, normalized with a last row equal to (0,…,0,1).

class Linear(linear)

A linear transformation parameterized by a matrix \(M \in \mathcal{M}_{D,C}(\mathbb{R})\), transforming a point \(x \in \mathbb{R}^C\) into \(M x\).

Variables:

linear – (…xDxC tensor): batch of matrices specifying the transformations considered.

apply(v)

Transforms a tensor of point coordinates. See Applying a transformation.

Parameters:

v (...xD tensor) – tensor of point coordinates to transform.

Returns:

The transformed point coordinates.

clone()
Returns:

A copy of the transformation (useful to avoid aliasing issues).

compose(other)

Compose a transformation with the current one.

Parameters:

other – another transformation of the same type.

Returns:

The resulting transformation.

inverse()
Returns:

The inverse transformation, when applicable.

linear_apply(v)

Transforms a tensor of vector coordinates.

Parameters:

v (...xD tensor) – tensor of vector coordinates to transform.

Returns:

The transformed vector coordinates.

See note in apply() regarding broadcasting.

linear_compose(other)

Compose the linear part of two transformations.

Parameters:

other – another transformation of the same type.

Returns:

a tensor representing the composed transformation.

linear_inverse()
Returns:

The inverse of the linear transformation, when applicable.

normalize()
Returns:

Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).

class Orthonormal(linear)

An orthogonal transformation represented by an orthonormal matrix \(M \in \mathcal{M}_{D,D}(\mathbb{R})\), transforming a point \(x \in \mathbb{R}^D\) into \(M x\).

Variables:

linear – (…xDxD tensor): batch of matrices \(M\) specifying the transformations considered.

apply(v)

Transforms a tensor of point coordinates. See Applying a transformation.

Parameters:

v (...xD tensor) – tensor of point coordinates to transform.

Returns:

The transformed point coordinates.

clone()
Returns:

A copy of the transformation (useful to avoid aliasing issues).

compose(other)

Compose a transformation with the current one.

Parameters:

other – another transformation of the same type.

Returns:

The resulting transformation.

inverse()
Returns:

The inverse transformation, when applicable.

linear_apply(v)

Transforms a tensor of vector coordinates.

Parameters:

v (...xD tensor) – tensor of vector coordinates to transform.

Returns:

The transformed vector coordinates.

See note in apply() regarding broadcasting.

linear_compose(other)

Compose the linear part of two transformations.

Parameters:

other – another transformation of the same type.

Returns:

a tensor representing the composed transformation.

linear_inverse()
Returns:

The inverse of the linear transformation, when applicable.

linear_normalize()
Returns:

Linear transformation normalized to an orthonormal matrix (…xDxD tensor).

normalize()
Returns:

Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).

class Rigid(linear, translation)

A rigid transformation represented by a rotation and a translation part.

Variables:
  • linear – (…xDxD tensor): batch of matrices specifying the linear part.

  • translation – (…xD tensor): batch of vectors specifying the translation part.

apply(v)

Transforms a tensor of point coordinates. See Applying a transformation.

Parameters:

v (...xD tensor) – tensor of point coordinates to transform.

Returns:

The transformed point coordinates.

as_tuple()
Returns:

a tuple of tensors containing the linear and translation parts of the transformation respectively.

clone()
Returns:

A copy of the transformation (useful to avoid aliasing issues).

compose(other)

Compose a transformation with the current one.

Parameters:

other – another transformation of the same type.

Returns:

The resulting transformation.

classmethod from_homogeneous(matrix)

Instantiate a new transformation from an input homogeneous (D+1)x(C+1) matrix. The input matrix is assumed to be normalized and to satisfy the properties of the transformation. No checks are performed.

Parameters:

matrix (...x(D+1)x(C+1) tensor) – tensor of transformations expressed in homogeneous coordinates, normalized with a last row equal to (0,…,0,1).

Returns:

The corresponding transformation.

inverse()
Returns:

The inverse transformation, when applicable.

linear_apply(v)

Transforms a tensor of vector coordinates.

Parameters:

v (...xD tensor) – tensor of vector coordinates to transform.

Returns:

The transformed vector coordinates.

See note in apply() regarding broadcasting.

linear_compose(other)

Compose the linear part of two transformations.

Parameters:

other – another transformation of the same type.

Returns:

a tensor representing the composed transformation.

linear_inverse()
Returns:

The inverse of the linear transformation, when applicable.

linear_normalize()
Returns:

Linear transformation normalized to a rotation matrix (…xDxD tensor).

normalize()
Returns:

Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).

to_homogeneous(output=None)
Parameters:

output (...x(D+1)x(C+1) tensor or None) – optional tensor in which to store the result.

Returns:

A …x(D+1)x(C+1) tensor of homogeneous matrices representing the transformation, normalized with a last row equal to (0,…,0,1).

to_rigidunitquat()

Returns the corresponding RigidUnitQuat transformation.

Note

Original and resulting transformations share the same translation tensor. Be careful in case of in-place modifications.

class RigidUnitQuat(linear, translation)

A rigid transformation represented by a unit quaternion and a translation part.

Variables:
  • linear – (…x4 tensor): batch of unit quaternions defining the rotation.

  • translation – (…x3 tensor): batch of vectors specifying the translation part.

Note

Quaternions are assumed to be of unit norm in all internal operations. Use the normalize() method if needed.

apply(v)

Transforms a tensor of point coordinates. See Applying a transformation.

Parameters:

v (...xD tensor) – tensor of point coordinates to transform.

Returns:

The transformed point coordinates.

as_tuple()
Returns:

a tuple of tensors containing the linear and translation parts of the transformation respectively.

clone()
Returns:

A copy of the transformation (useful to avoid aliasing issues).

compose(other)

Compose a transformation with the current one.

Parameters:

other – another transformation of the same type.

Returns:

The resulting transformation.

static from_homogeneous(matrix)

Instantiate a new transformation from an input homogeneous (D+1)x(D+1) matrix.

Parameters:

matrix (...x(D+1)x(D+1) tensor) – tensor of transformations expressed in homogeneous coordinates, normalized with a last row equal to (0,…,0,1).

Returns:

The corresponding transformation.

Note

  • The input matrix is not tested to ensure that it satisfies the required properties of the transformation.

  • Components of the resulting transformation may be views of the input matrix. Be careful if you intend to modify it in-place.

inverse()
Returns:

The inverse transformation, when applicable.

linear_apply(v)

Transforms a tensor of vector coordinates.

Parameters:

v (...xD tensor) – tensor of vector coordinates to transform.

Returns:

The transformed vector coordinates.

See note in apply() regarding broadcasting.

linear_compose(other)

Compose the linear part of two transformations.

Parameters:

other – another transformation of the same type.

Returns:

a tensor representing the composed transformation.

linear_inverse()
Returns:

The inverse of the linear transformation, when applicable.

linear_normalize()
Returns:

Normalized unit quaternion (…x4 tensor).
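Normalizing a batch of quaternions amounts to dividing each 4D vector by its Euclidean norm. A minimal sketch in plain PyTorch, with illustrative values:

```python
import torch

# Two non-unit quaternions in XYZW order (batch shape 2).
q = torch.tensor([[0., 0., 3., 4.],
                  [1., 1., 1., 1.]])

# Divide each quaternion by its norm, keeping the last dimension for broadcasting.
q_unit = q / torch.linalg.norm(q, dim=-1, keepdim=True)

print(q_unit[0])  # tensor([0.0000, 0.0000, 0.6000, 0.8000])
```

This sketch does not handle quaternions of zero norm, for which the division is undefined.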

normalize()
Returns:

Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).

to_homogeneous(output=None)
Parameters:

output (...x4x4 tensor or None) – tensor in which to store the result (optional).

Returns:

A …x4x4 tensor of homogeneous matrices representing the transformation, normalized with a last row equal to (0,…,0,1).

to_rigid()

Returns the corresponding Rigid transformation.

Note

Original and resulting transformations share the same translation tensor. Be careful in case of in-place modifications.
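Converting from the quaternion form to the matrix form only changes the rotation part. The standard formula for turning a unit quaternion in XYZW order into a 3x3 rotation matrix is sketched below in plain PyTorch. The helper name is hypothetical, and this is not necessarily how to_rigid() is implemented.

```python
import torch

def unitquat_to_rotmat_sketch(q):
    """Convert ...x4 unit quaternions (XYZW) into ...x3x3 rotation matrices."""
    x, y, z, w = q.unbind(-1)
    rows = [
        torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], dim=-1),
        torch.stack([2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], dim=-1),
        torch.stack([2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], dim=-1),
    ]
    return torch.stack(rows, dim=-2)

# Rotation of 90 degrees around the z axis: q = (0, 0, sin(45 deg), cos(45 deg)).
s = 0.5 ** 0.5
q = torch.tensor([0., 0., s, s])
R = unitquat_to_rotmat_sketch(q)
print(R)
```

Up to rounding, the printed matrix maps the x axis onto the y axis, as expected for this rotation.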

class Rotation(linear)

A rotation represented by a rotation matrix \(R \in \mathcal{M}_{D,D}(\mathbb{R})\), transforming a point \(x \in \mathbb{R}^D\) into \(R x\).

Variables:

linear – (…xDxD tensor): batch of matrices \(R\) defining the rotation.

apply(v)

Transforms a tensor of point coordinates. See Applying a transformation.

Parameters:

v (...xD tensor) – tensor of point coordinates to transform.

Returns:

The transformed point coordinates.

clone()
Returns:

A copy of the transformation (useful to avoid aliasing issues).

compose(other)

Compose a transformation with the current one.

Parameters:

other – another transformation of the same type.

Returns:

The resulting transformation.

inverse()
Returns:

The inverse transformation, when applicable.

linear_apply(v)

Transforms a tensor of vector coordinates.

Parameters:

v (...xD tensor) – tensor of vector coordinates to transform.

Returns:

The transformed vector coordinates.

See note in apply() regarding broadcasting.

linear_compose(other)

Compose the linear part of two transformations.

Parameters:

other – another transformation of the same type.

Returns:

a tensor representing the composed transformation.

linear_inverse()
Returns:

The inverse of the linear transformation, when applicable.

linear_normalize()
Returns:

Linear transformation normalized to a rotation matrix (…xDxD tensor).
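Projecting onto a rotation matrix differs from plain orthonormalization in that the determinant must be +1. A common approach is the special orthogonal Procrustes solution, which flips the sign of the last singular direction when needed. The sketch below shows that technique in plain PyTorch for a single 3x3 matrix. It is an assumption that linear_normalize() behaves this way, and the helper name is hypothetical.

```python
import torch

def rotation_normalize_sketch(M):
    """Closest rotation matrix to a 3x3 matrix M in the Frobenius sense (hypothetical helper)."""
    U, _, Vh = torch.linalg.svd(M)
    # Flip the last singular direction if U @ Vh would be a reflection.
    D = torch.eye(3, dtype=M.dtype)
    D[2, 2] = torch.sign(torch.det(U @ Vh))
    return U @ D @ Vh

# A matrix with negative determinant: plain orthonormalization would yield a reflection.
M = torch.tensor([[1.0, 0.1,  0.0],
                  [0.0, 1.0,  0.2],
                  [0.0, 0.0, -1.0]], dtype=torch.float64)
R = rotation_normalize_sketch(M)

print(torch.det(R))  # close to +1
```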

normalize()
Returns:

Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).

class RotationUnitQuat(linear)

A 3D rotation represented by a unit quaternion.

Variables:

linear – (…x4 tensor, XYZW convention): batch of unit quaternions defining the rotation.

Note

Quaternions are assumed to be of unit norm in all internal operations. Use normalize() if needed.

apply(v)

Transforms a tensor of point coordinates. See Applying a transformation.

Parameters:

v (...xD tensor) – tensor of point coordinates to transform.

Returns:

The transformed point coordinates.
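Rotating a 3D point by a unit quaternion does not require building a rotation matrix. Writing q = (u, w) with u the XYZW vector part and w the scalar part, the rotated point is v + 2 w (u × v) + 2 u × (u × v). The sketch below implements this formula in plain PyTorch. The helper name is hypothetical and the code is not taken from the library.

```python
import torch

def quat_action_sketch(q, v):
    """Rotate ...x3 points v by ...x4 unit quaternions q in XYZW order."""
    u, w = q[..., :3], q[..., 3:]
    u, v = torch.broadcast_tensors(u, v)
    uv = torch.cross(u, v, dim=-1)
    return v + 2 * (w * uv + torch.cross(u, uv, dim=-1))

# Rotation of 90 degrees around the z axis.
s = 0.5 ** 0.5
q = torch.tensor([0., 0., s, s])

rotated = quat_action_sketch(q, torch.tensor([1., 0., 0.]))
print(rotated)  # approximately (0, 1, 0)

# The same quaternion broadcasts over a batch of points.
batch = quat_action_sketch(q, torch.randn(5, 3))
print(batch.shape)
```

The formula is only a rotation when q has unit norm, which is why normalization matters for this class.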

clone()
Returns:

A copy of the transformation (useful to avoid aliasing issues).

compose(other)

Compose a transformation with the current one.

Parameters:

other – another transformation of the same type.

Returns:

The resulting transformation.

inverse()
Returns:

The inverse transformation, when applicable.

linear_apply(v)

Transforms a tensor of vector coordinates.

Parameters:

v (...xD tensor) – tensor of vector coordinates to transform.

Returns:

The transformed vector coordinates.

See note in apply() regarding broadcasting.

linear_compose(other)

Compose the linear part of two transformations.

Parameters:

other – another transformation of the same type.

Returns:

a tensor representing the composed transformation.

linear_inverse()
Returns:

The inverse of the linear transformation, when applicable.

linear_normalize()
Returns:

Normalized unit quaternion (…x4 tensor).

normalize()
Returns:

Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).

Advanced

Running unit tests

From source repository:

python -m unittest

Building Sphinx documentation

From source repository:

./build_doc.sh

License

RoMa, Copyright (c) 2021 NAVER Corp., is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license (see license).

Bits of code were adapted from SciPy. Documentation is generated, distributed and displayed with the support of Sphinx and other materials (see notice).

Changelog

Version 1.4.5:
  • 3-Clause BSD Licensing.

Version 1.4.4:
Version 1.4.3:
Version 1.4.2:
  • Fix for quat_action() to support arbitrary devices and types.

  • Added conversion functions between Rigid and RigidUnitQuat.

Version 1.4.1:
Version 1.4.0:
Version 1.3.4:
Version 1.3.3:
Version 1.3.2:
Version 1.3.1:
Version 1.3.0:
Version 1.2.7:
  • Fix of unflatten_batch_dims() to ensure compatibility with PyTorch 1.6.0.

  • Fix of symmatrixvec_to_unitquat() that was not producing a lower triangular matrix.

Version 1.2.6:
Version 1.2.5:
  • Added an optional regularization argument for Procrustes orthonormalization.

  • Added a rigid registration example in the documentation.

Version 1.2.4:
  • Procrustes: automatic fallback to vanilla SVD decomposition for large dimensions.

Version 1.2.3:
  • Improved support for double precision tensors.

Version 1.2.2:
Version 1.2.1:
  • Open-source release.