RoMa: A lightweight library to deal with 3D rotations in PyTorch.
RoMa (which stands for Rotation Manipulation) provides differentiable mappings between 3D rotation representations, mappings from Euclidean to rotation space, and various utilities related to rotations. It is implemented in PyTorch and aims to be an easy-to-use and reasonably efficient toolbox for Machine Learning and gradient-based optimization.
Please cite this work if you use RoMa for your research:
@inproceedings{bregier2021deepregression,
title={Deep Regression on Manifolds: a {3D} Rotation Case Study},
author={Br{\'e}gier, Romain},
journal={2021 International Conference on 3D Vision (3DV)},
year={2021}
}
Installation
The easiest way to install RoMa is to use pip:
pip install roma
Alternatively one can install the latest version of RoMa directly from the source repository:
pip install git+https://github.com/naver/roma
With old PyTorch versions (torch<1.8), we recommend installing torch-batch-svd to achieve a significant speed-up with procrustes() on CUDA GPUs (see section Why a new library?). You can check that this module is properly loaded using the function is_torch_batch_svd_available().
With recent PyTorch installations (torch>=1.8), torch-batch-svd is no longer needed or used.
Main features
Supported rotation representations
- Rotation vector (rotvec)
Encoded using a …x3 tensor.
The 3D vector angle * axis represents a rotation of magnitude angle (expressed in radians) around the unit 3D vector axis.
- Unit quaternion (unitquat)
Encoded as a …x4 tensor.
Note
We use XYZW quaternion convention, i.e. components of quaternion \(x i + y j + z k + w\) are represented by the 4D vector \((x,y,z,w)\).
We assume unit quaternions to be of unit length, and do not perform implicit normalization.
- Rotation matrix (rotmat)
Encoded as a …xDxD tensor (D=3 for 3D rotations).
We use column-vector convention, i.e. \(R X\) is the transformation of a Dx1 column vector \(X\) by a rotation matrix \(R\).
- Euler angles and Tait-Bryan angles (euler)
Encoded as a …xD tensor or a list of D tensors corresponding to each angle (D=3 for typical Euler angles conventions).
We provide mappings between Euler angles and other rotation representations. Euler angles suffer from shortcomings such as gimbal lock, and we recommend using quaternions or rotation matrices to perform actual computations.
Mappings between rotation representations
RoMa provides functions to convert between rotation representations.
Example mapping a batch of rotation vectors into corresponding unit quaternions:
import torch, roma
batch_shape = (3, 2)
rotvec = torch.randn(batch_shape + (3,))
q = roma.rotvec_to_unitquat(rotvec)
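To make the conventions above concrete, here is a minimal plain-Python sketch of the rotation vector to unit quaternion formula, written for a single vector and using only the standard library. It is not RoMa's implementation: the function name rotvec_to_unitquat_xyzw and its epsilon threshold are hypothetical, chosen for illustration only.

```python
import math

def rotvec_to_unitquat_xyzw(rotvec, epsilon=1e-8):
    """Map a rotation vector (angle * axis) to a unit quaternion (x, y, z, w).

    Hypothetical helper for illustration; not part of RoMa.
    """
    rx, ry, rz = rotvec
    angle = math.sqrt(rx * rx + ry * ry + rz * rz)
    if angle < epsilon:
        # Small angle: sin(angle / 2) / angle tends to 1/2.
        scale = 0.5
    else:
        scale = math.sin(0.5 * angle) / angle
    # Vector part first, scalar part last (XYZW convention).
    return (scale * rx, scale * ry, scale * rz, math.cos(0.5 * angle))

# Rotation of pi/2 radians around the z axis.
q = rotvec_to_unitquat_xyzw((0.0, 0.0, math.pi / 2))
print(q)
```

The scalar part cos(angle / 2) is stored last, which matches the XYZW convention described earlier.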
Mappings from Euclidean to 3D rotation space
Mapping an arbitrary tensor to a valid rotation can be useful e.g. for Machine Learning applications. While rotation vectors or Euler angles can be used for such purpose, they suffer from various shortcomings, and we therefore provide the following alternative mappings:
special_gramschmidt()
Mapping from a 3x2 tensor to a 3x3 rotation matrix, using special Gram-Schmidt orthonormalization (6D representation, popularized by Zhou et al.).
special_procrustes()
Mapping from an arbitrary nxn matrix to an nxn rotation matrix, using special orthogonal Procrustes orthonormalization.
symmatrixvec_to_unitquat()
Mapping from a 10D vector to an antipodal pair of quaternions through eigenvector decomposition of a 4x4 symmetric matrix, proposed by Peretroukhin et al.
For general purpose applications, we recommend the use of special_procrustes(), which projects an arbitrary square matrix onto the closest matrix of the rotation space with respect to the Frobenius norm. Please refer to this paper for more insights.
Example mapping random 3x3 matrices to valid rotation matrices:
import torch, roma
batch_shape = (5,)
M = torch.randn(batch_shape + (3,3))
R = roma.special_procrustes(M)
assert roma.is_rotation_matrix(R, epsilon=1e-5)
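For intuition about the 6D representation, the special Gram-Schmidt mapping can be sketched in plain Python for a single pair of 3D vectors. This is an illustrative sketch under the usual definition (normalize the first vector, orthogonalize and normalize the second, complete with a cross product); the name special_gramschmidt_sketch is hypothetical and the code is not taken from RoMa.

```python
import math

def special_gramschmidt_sketch(a, b):
    """Map two 3D vectors to a 3x3 rotation matrix with columns (e1, e2, e3).

    Hypothetical helper for illustration; input vectors must not be colinear.
    """
    def dot(u, v):
        return sum(ui * vi for ui, vi in zip(u, v))

    def normalize(u):
        n = math.sqrt(dot(u, u))
        return [ui / n for ui in u]

    e1 = normalize(a)
    # Remove from b its component along e1, then normalize.
    proj = dot(e1, b)
    e2 = normalize([bi - proj * ei for bi, ei in zip(b, e1)])
    # The cross product e1 x e2 completes a right-handed basis.
    e3 = [e1[1] * e2[2] - e1[2] * e2[1],
          e1[2] * e2[0] - e1[0] * e2[2],
          e1[0] * e2[1] - e1[1] * e2[0]]
    # Return the matrix as a list of rows; its columns are e1, e2, e3.
    return [[e1[i], e2[i], e3[i]] for i in range(3)]

R = special_gramschmidt_sketch([1.0, 1.0, 0.0], [0.0, 2.0, 1.0])
```

Because the third column is obtained by a cross product, the determinant is +1 by construction, which is what distinguishes this "special" variant from plain orthonormalization.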
Support for an arbitrary number of batch dimensions
For convenience, functions accept an arbitrary number of batch dimensions:
import torch, roma
print(roma.rotvec_to_rotmat(torch.randn(3)).shape) # -> torch.Size([3, 3])
print(roma.rotvec_to_rotmat(torch.randn(5, 3)).shape) # -> torch.Size([5, 3, 3])
print(roma.rotvec_to_rotmat(torch.randn(2, 5, 3)).shape) # -> torch.Size([2, 5, 3, 3])
Quaternion operations
import torch, roma
q = torch.randn(4) # Random unnormalized quaternion
qconv = roma.quat_conjugation(q) # Quaternion conjugation
qinv = roma.quat_inverse(q) # Quaternion inverse
print(roma.quat_product(q, qinv)) # -> [0,0,0,1] identity quaternion
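The operations above follow the standard Hamilton product. As a rough illustration of what they compute in XYZW convention, here is a plain-Python sketch for single quaternions stored as 4-tuples. The function names are hypothetical and the code is independent of RoMa's batched implementation.

```python
def quat_product_xyzw(p, q):
    """Hamilton product p * q of two quaternions given as (x, y, z, w)."""
    px, py, pz, pw = p
    qx, qy, qz, qw = q
    return (
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
        pw * qw - px * qx - py * qy - pz * qz,
    )

def quat_conjugation_xyzw(q):
    """Conjugate: negate the vector part, keep the scalar part."""
    x, y, z, w = q
    return (-x, -y, -z, w)

def quat_inverse_xyzw(q):
    """Inverse: conjugate divided by the squared norm (undefined for the null quaternion)."""
    n2 = sum(c * c for c in q)
    return tuple(c / n2 for c in quat_conjugation_xyzw(q))

q = (1.0, 2.0, 3.0, 4.0)  # unnormalized quaternion
identity = quat_product_xyzw(q, quat_inverse_xyzw(q))
print(identity)  # close to the identity quaternion (0, 0, 0, 1)
```

For unit quaternions the squared norm is 1, so the inverse reduces to the conjugate.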
Rotation composition and inverse
Example using rotation vector representation:
import torch, roma
rotvecs = [roma.random_rotvec() for _ in range(3)] # Random rotation vectors
r = roma.rotvec_composition(rotvecs) # Composition of an arbitrary number of rotations
rinv = roma.rotvec_inverse(r) # Rotation vector corresponding to the inverse rotation
Rotation metrics
RoMa implements some usual similarity measures over the 3D rotation space:
import torch, roma
R1, R2 = roma.random_rotmat(size=5), roma.random_rotmat(size=5)
theta = roma.rotmat_geodesic_distance(R1, R2) # In radians
cos_theta = roma.rotmat_cosine_angle(R1.transpose(-2, -1) @ R2)
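The API documentation states that rotmat_geodesic_distance() relies on the equality between the Frobenius norm of R2 - R1 and 2 sqrt(2) sin(alpha/2), whereas rotmat_geodesic_distance_naive() goes through the cosine of the angle, obtained from the trace. The plain-Python sketch below compares both formulas for a very small rotation around the z axis. It uses 64-bit Python floats rather than 32-bit tensors, and all function names are hypothetical, so it only illustrates the numerical argument and not RoMa's code.

```python
import math

def rot_z(theta):
    """3x3 rotation matrix of angle theta around the z axis (list of rows)."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def geodesic_distance_frobenius(R1, R2):
    # Uses ||R2 - R1||_F = 2 sqrt(2) sin(alpha / 2).
    sq = sum((R2[i][j] - R1[i][j]) ** 2 for i in range(3) for j in range(3))
    x = math.sqrt(sq) / (2.0 * math.sqrt(2.0))
    return 2.0 * math.asin(min(x, 1.0))

def geodesic_distance_trace(R1, R2):
    # Uses Trace(R1^T R2) = 1 + 2 cos(alpha).
    trace = sum(R1[k][i] * R2[k][i] for i in range(3) for k in range(3))
    c = 0.5 * (trace - 1.0)
    return math.acos(max(-1.0, min(1.0, c)))

identity = rot_z(0.0)
small = rot_z(1e-9)
d_frobenius = geodesic_distance_frobenius(identity, small)
d_trace = geodesic_distance_trace(identity, small)
print(d_frobenius, d_trace)
```

For such a small angle the cosine is indistinguishable from 1 in floating point, so the trace-based estimate cannot recover the angle, while the Frobenius-based one stays accurate. Both agree for larger angles.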
Weighted rotation averaging
special_procrustes() can be used to easily average rotations:
import torch, roma
n = 5
R_i = roma.random_rotmat(n) # Batch of n 3x3 rotation matrices to average
w_i = torch.rand(n) # Weight of each matrix, between 0 and 1
M = torch.sum(w_i[:,None, None] * R_i, dim=0) # 3x3 matrix
R = roma.special_procrustes(M) # weighted average.
To be precise, the result is the Fréchet mean with respect to the chordal distance.
Note
The same average could be performed using quaternion representation and symmatrix mapping (slower batched implementation on GPU).
Rigid registration
rigid_points_registration() and rigid_vectors_registration() align ordered sets of points/vectors:
import torch, roma
R_gt = roma.random_rotmat()
t_gt = torch.randn(1, 3)
src = torch.randn(100, 3) # source points / vectors
target = src @ R_gt.T # target vectors
R_predicted = roma.rigid_vectors_registration(src, target)
print(f"R_gt\n{R_gt}")
print(f"R_predicted\n{R_predicted}")
target = src @ R_gt.T + t_gt # target points
R_predicted, t_predicted = roma.rigid_points_registration(src, target)
print(f"R_gt\n{R_gt}")
print(f"R_predicted\n{R_predicted}")
print(f"t_gt\n{t_gt}")
print(f"t_predicted\n{t_predicted}")
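For intuition about the least-squares problem being solved, the 2D case (D=2) admits a closed form that needs no SVD: after centering both point sets, the optimal angle is atan2 of the summed cross products over the summed dot products. The sketch below is a plain-Python illustration of that special case, without weights or scaling. The function name is hypothetical and this is not how RoMa implements registration.

```python
import math

def rigid_registration_2d(src, dst):
    """Least-squares rotation angle and translation aligning 2D points src to dst.

    Hypothetical helper for illustration (unweighted, no scaling).
    """
    n = len(src)
    cx = [sum(p[k] for p in src) / n for k in range(2)]  # source centroid
    cy = [sum(p[k] for p in dst) / n for k in range(2)]  # target centroid
    dot = cross = 0.0
    for (x0, x1), (y0, y1) in zip(src, dst):
        a0, a1 = x0 - cx[0], x1 - cx[1]
        b0, b1 = y0 - cy[0], y1 - cy[1]
        dot += a0 * b0 + a1 * b1
        cross += a0 * b1 - a1 * b0
    theta = math.atan2(cross, dot)
    c, s = math.cos(theta), math.sin(theta)
    # Translation maps the rotated source centroid onto the target centroid.
    t = (cy[0] - (c * cx[0] - s * cx[1]), cy[1] - (s * cx[0] + c * cx[1]))
    return theta, t

theta_gt, t_gt = 0.3, (1.0, -2.0)
src = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (-1.5, 0.5)]
dst = [(math.cos(theta_gt) * x - math.sin(theta_gt) * y + t_gt[0],
        math.sin(theta_gt) * x + math.cos(theta_gt) * y + t_gt[1])
       for x, y in src]
theta, t = rigid_registration_2d(src, dst)
print(theta, t)
```

In higher dimensions the rotation is usually obtained from an SVD of the cross-covariance matrix instead, as in the Kabsch/Umeyama algorithm referenced in the API documentation.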
Spherical linear interpolation (SLERP)
SLERP between batches of unit quaternions:
import torch, roma
q0, q1 = roma.random_unitquat(size=10), roma.random_unitquat(size=10)
steps = torch.linspace(0, 1.0, 5)
q_interpolated = roma.utils.unitquat_slerp(q0, q1, steps)
idx = 1 # Print interpolations for an arbitrary element of the batch
print('q0:\n', q0[idx])
print('q1:\n', q1[idx])
print('q_interpolated:\n', q_interpolated[:,idx])
SLERP between rotation vectors (shortest path interpolation):
import torch, roma
steps = torch.linspace(0, 1.0, 5)
rotvec0, rotvec1 = torch.randn(3), torch.randn(3)
rotvec_interpolated = roma.rotvec_slerp(rotvec0, rotvec1, steps)
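As a rough illustration of what spherical linear interpolation computes, here is a plain-Python sketch for a single pair of unit quaternions and a single step. It follows the textbook SLERP formula and is not RoMa's batched implementation. The function name and the small-angle threshold are hypothetical.

```python
import math

def unitquat_slerp_sketch(q0, q1, t, shortest_arc=True):
    """Interpolate between unit quaternions q0 and q1 (XYZW tuples) at step t."""
    d = sum(a * b for a, b in zip(q0, q1))
    if shortest_arc and d < 0.0:
        # q1 and -q1 encode the same rotation: pick the closer representative.
        q1 = tuple(-c for c in q1)
        d = -d
    d = max(-1.0, min(1.0, d))
    omega = math.acos(d)  # angle between q0 and q1 on the unit 3-sphere
    if omega < 1e-8:
        # Nearly identical inputs: fall back to linear interpolation.
        w0, w1 = 1.0 - t, t
    else:
        w0 = math.sin((1.0 - t) * omega) / math.sin(omega)
        w1 = math.sin(t * omega) / math.sin(omega)
    q = tuple(w0 * a + w1 * b for a, b in zip(q0, q1))
    n = math.sqrt(sum(c * c for c in q))
    return tuple(c / n for c in q)

q_identity = (0.0, 0.0, 0.0, 1.0)
q_z90 = (0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4))  # 90 degrees around z
q_mid = unitquat_slerp_sketch(q_identity, q_z90, 0.5)
print(q_mid)  # rotation of 45 degrees around z
```

With shortest_arc=True, passing q1 or -q1 yields the same interpolated rotation.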
Why a new library?
- We could not find a PyTorch library satisfying our needs, so we built our own.
We wanted a reliable, easy-to-use and efficient toolbox to deal with rotation representations in PyTorch. While Kornia provides some utility functions to deal with 3D rotations, it included several major bugs at the time of writing (early 2021) (see e.g. https://github.com/kornia/kornia/issues/723 or https://github.com/kornia/kornia/issues/317).
- Care for numerical precision
RoMa is implemented with numerical precision in mind, e.g. with a special handling of small angle rotation vectors or through the choice of appropriate algorithms.
As an example, below is plotted a function that takes as input an angle \(\theta\), produces a rotation matrix \(R_z(\theta)\) of angle \(\theta\) and estimates its geodesic distance with respect to the identity matrix, using 32-bit floating point arithmetic. We observe that rotmat_geodesic_distance() is much more precise for this task than another implementation often found in academic code: rotmat_geodesic_distance_naive(). The backward pass through rotmat_geodesic_distance_naive() leads to unstable gradient estimations and produces Not-a-Number values for small angles, whereas rotmat_geodesic_distance() is well-behaved, and returns Not-a-Number only for a 0.0 angle, where the gradient is mathematically undefined.
- Computation efficiency
RoMa favors code clarity, but aims to be reasonably efficient.
In particular, for Procrustes orthonormalization it can use on NVidia GPUs a batched SVD decomposition that provides orders of magnitude speed-ups for large batch sizes compared to vanilla torch.svd() for PyTorch versions below 1.8. The plot below was obtained for random 3x3 matrices, with PyTorch 1.7, an NVidia Tesla T4 GPU and CUDA 11.0. Note that recent versions of PyTorch (>=1.8) integrate such speed-up off-the-shelf.
- Syntactic sugar
RoMa aims to be easy-to-use with a simple syntax, and support for an arbitrary number of batch dimensions to let users focus on their applications.
API Documentation
Mappings
Various mappings between different rotation representations.
- procrustes(M, force_rotation=False, regularization=0.0, gradient_eps=1e-05, return_singular_values: bool = False)
Returns the orthonormal matrix \(R\) minimizing Frobenius norm \(\| M - R \|_F\).
- Parameters:
M (...xNxN tensor) – batch of square matrices.
force_rotation (bool) – if True, forces the output to be a rotation matrix.
regularization (float >= 0) – weight of a regularization term added to the gradient. Using this regularization is equivalent to adding a term \(regularization * \| M - R \|_F^2\) to the training loss function.
gradient_eps (float > 0) – small value used to enforce numerical stability during backpropagation.
- Returns:
batch of orthonormal matrices (…xNxN tensor) and optional singular values. For advanced users, singular values of the SVD decomposition with sign flipping (… tensor) can optionally be returned by setting the argument return_singular_values to True.
- procrustes_naive(M, force_rotation: bool = False, return_singular_values: bool = False)
Implementation of procrustes() relying on default backward pass of autograd and SVD decomposition. Could be slightly less stable than procrustes().
- quat_wxyz_to_xyzw(wxyz)
Convert quaternion from WXYZ to XYZW convention.
- Parameters:
wxyz (...x4 tensor, WXYZ convention) – batch of quaternions.
- Returns:
batch of quaternions (…x4 tensor, XYZW convention).
- quat_xyzw_to_wxyz(xyzw)
Convert quaternion from XYZW to WXYZ convention.
- Parameters:
xyzw (...x4 tensor, XYZW convention) – batch of quaternions.
- Returns:
batch of quaternions (…x4 tensor, WXYZ convention).
- rotmat_to_rotvec(R)
Converts rotation matrix to rotation vector representation.
- Parameters:
R (...x3x3 tensor) – batch of rotation matrices.
- Returns:
batch of rotation vectors (…x3 tensor).
- rotmat_to_unitquat(R)
Converts rotation matrix to unit quaternion representation.
- Parameters:
R (...x3x3 tensor) – batch of rotation matrices.
- Returns:
batch of unit quaternions (…x4 tensor, XYZW convention).
- rotvec_to_rotmat(rotvec: Tensor, epsilon=1e-06) -> Tensor
Converts rotation vector to rotation matrix representation. Conversion uses Rodrigues formula in general, and a first order approximation for small angles.
- Parameters:
rotvec (...x3 tensor) – batch of rotation vectors.
epsilon (float) – small angle threshold.
- Returns:
batch of rotation matrices (…x3x3 tensor).
- rotvec_to_unitquat(rotvec)
Converts rotation vector into unit quaternion representation.
- Parameters:
rotvec (...x3 tensor) – batch of rotation vectors.
- Returns:
batch of unit quaternions (…x4 tensor, XYZW convention).
- special_gramschmidt(M, epsilon=0)
Returns the 3x3 rotation matrix obtained by Gram-Schmidt orthonormalization of two 3D input vectors (first two columns of input matrix M).
- Parameters:
M (...x3xN tensor) – batch of 3xN matrices, with N >= 2. Only the first two columns of the matrices are used for orthonormalization.
epsilon (float >= 0) – optional clamping value to avoid returning Not-a-Number values in case of ill-defined input.
- Returns:
batch of rotation matrices (…x3x3 tensor).
Warning
In case of ill-defined input (colinear input column vectors), the output will not be a rotation matrix.
- special_procrustes(M, regularization=0.0, gradient_eps=1e-05, return_singular_values: bool = False)
Returns the rotation matrix \(R\) minimizing Frobenius norm \(\| M - R \|_F\).
- Parameters:
M (...xNxN tensor) – batch of square matrices.
regularization (float >= 0) – weight of a regularization term added to the gradient. Using this regularization is equivalent to adding a term \(regularization * \| M - R \|_F^2\) to the training loss function.
gradient_eps (float > 0) – small value used to enforce numerical stability during backpropagation.
- Returns:
batch of rotation matrices (…xNxN tensor). For advanced users, singular values of the SVD decomposition with sign flipping (… tensor) can optionally be returned by setting the argument return_singular_values to True.
- special_procrustes_naive(M, return_singular_values: bool = False)
Implementation of special_procrustes() relying on default backward pass of autograd and SVD decomposition. Could be slightly less stable than special_procrustes().
- symmatrix_to_projective_point(A)
Converts a DxD symmetric matrix A into a projective point represented by a unit vector \(q\) minimizing \(q^T A q\). Only the lower part of the matrix is considered in practice.
- Parameters:
A (...xDxD tensor) – batch of symmetric matrices. Only the lower triangular part is considered.
- Returns:
batch of unit vectors \(q\) (…xD tensor).
- Reference:
Peretroukhin, M. Giamou, D. M. Rosen, W. N. Greene, N. Roy, and J. Kelly, “A Smooth Representation of Belief over SO(3) for Deep Rotation Learning with Uncertainty,” 2020, doi: 10.15607/RSS.2020.XVI.007.
Warning
This mapping is unstable when the smallest eigenvalue of A has a multiplicity strictly greater than 1.
The eigenvalue decomposition may fail, in particular when using single precision numbers.
Current implementation is rather slow due to the implementation of torch.symeig. The CuSolver library provides a faster eigenvalue decomposition alternative, but results were found to be unreliable.
- symmatrixvec_to_unitquat(x)
Converts a 10D vector into a unit quaternion representation. Based on symmatrix_to_projective_point().
- Parameters:
x (...x10 tensor) – batch of 10D vectors.
- Returns:
batch of unit quaternions (…x4 tensor, XYZW convention).
- Reference:
Peretroukhin, M. Giamou, D. M. Rosen, W. N. Greene, N. Roy, and J. Kelly, “A Smooth Representation of Belief over SO(3) for Deep Rotation Learning with Uncertainty,” 2020, doi: 10.15607/RSS.2020.XVI.007.
- unitquat_to_rotmat(quat)
Converts unit quaternion into rotation matrix representation.
- Parameters:
quat (...x4 tensor, XYZW convention) – batch of unit quaternions. No normalization is applied before computation.
- Returns:
batch of rotation matrices (…x3x3 tensor).
- unitquat_to_rotvec(quat, shortest_arc=True)
Converts unit quaternion into rotation vector representation.
Based on the representation of a rotation of angle \({\theta}\) and unit axis \((x,y,z)\) by the unit quaternions \(\pm [\sin({\theta} / 2) (x i + y j + z k) + \cos({\theta} / 2)]\).
- Parameters:
quat (...x4 tensor, XYZW convention) – batch of unit quaternions. No normalization is applied before computation.
shortest_arc (bool) – if True, the function returns the smallest rotation vectors corresponding to the input 3D rotations, i.e. rotation vectors with a norm smaller than \(\pi\). If False, the function may return rotation vectors of norm larger than \(\pi\), depending on the sign of the input quaternions.
- Returns:
batch of rotation vectors (…x3 tensor).
Note
Behavior is undefined for inputs quat=torch.as_tensor([0.0, 0.0, 0.0, -1.0]) and shortest_arc=False, as any rotation vector of angle \(2 \pi\) could be a valid representation in such case.
- euler_to_rotmat(convention: str, angles, degrees=False, dtype=None, device=None)
Convert Euler angles to rotation matrix representation.
- Parameters:
convention (string) – ‘xyz’ for example. See euler_to_unitquat().
angles (...xD tensor, or tuple/list of D floats or ... tensors) – a list of angles associated to each axis, expressed in radians by default.
degrees (bool) – if True, input angles are assumed to be expressed in degrees.
- Returns:
a batch of rotation matrices (…x3x3 tensor).
- euler_to_rotvec(convention: str, angles, degrees=False, dtype=None, device=None)
Convert Euler angles to rotation vector representation.
- Parameters:
convention (string) – ‘xyz’ for example. See euler_to_unitquat().
angles (...xD tensor, or tuple/list of D floats or ... tensors) – a list of angles associated to each axis, expressed in radians by default.
degrees (bool) – if True, input angles are assumed to be expressed in degrees.
- Returns:
a batch of rotation vectors (…x3 tensor).
- euler_to_unitquat(convention: str, angles, degrees=False, normalize=True, dtype=None, device=None)
Convert Euler angles to unit quaternion representation.
- Parameters:
convention (string) – string defining a sequence of D rotation axes (‘XYZ’ or ‘xzx’ for example). The sequence of rotations is expressed either with respect to a global ‘extrinsic’ coordinate system (in which case axes are denoted in lowercase: ‘x’, ‘y’, or ‘z’), or with respect to an ‘intrinsic’ coordinate system attached to the object under rotation (in which case axes are denoted in uppercase: ‘X’, ‘Y’, ‘Z’). Intrinsic and extrinsic conventions cannot be mixed.
angles (...xD tensor, or tuple/list of D floats or ... tensors) – a list of angles associated to each axis, expressed in radians by default.
degrees (bool) – if True, input angles are assumed to be expressed in degrees.
normalize (bool) – if True, normalize the returned quaternion to compensate for potential numerical errors.
- Returns:
A batch of unit quaternions (…x4 tensor, XYZW convention).
Warning
Case is important: ‘xyz’ and ‘XYZ’ denote different conventions.
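To illustrate why case matters, the plain-Python sketch below builds a rotation matrix from a convention string by composing elementary axis rotations. It assumes the usual definitions: extrinsic rotations are applied about fixed axes (each new rotation is multiplied on the left), while intrinsic rotations are applied about the moving axes (multiplied on the right). All function names are hypothetical and the code is not taken from RoMa.

```python
import math

def axis_rotation(axis, angle):
    """Elementary 3x3 rotation matrix about 'x', 'y' or 'z' (list of rows)."""
    c, s = math.cos(angle), math.sin(angle)
    if axis == "x":
        return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]
    if axis == "y":
        return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]
    if axis == "z":
        return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
    raise ValueError("Unknown axis: " + axis)

def matmul3(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def euler_to_rotmat_sketch(convention, angles):
    """Compose axis rotations; lowercase = extrinsic, uppercase = intrinsic."""
    intrinsic = convention.isupper()
    if not (intrinsic or convention.islower()):
        raise ValueError("Intrinsic and extrinsic conventions cannot be mixed.")
    R = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    for axis, angle in zip(convention.lower(), angles):
        A = axis_rotation(axis, angle)
        # Intrinsic: rotate about the already rotated axes (right-multiply).
        # Extrinsic: rotate about the fixed global axes (left-multiply).
        R = matmul3(R, A) if intrinsic else matmul3(A, R)
    return R

a, b, c = 0.1, 0.2, 0.3
R_extrinsic = euler_to_rotmat_sketch("xyz", (a, b, c))
R_intrinsic = euler_to_rotmat_sketch("XYZ", (a, b, c))
R_reversed = euler_to_rotmat_sketch("ZYX", (c, b, a))
```

Under these definitions, extrinsic ‘xyz’ with angles (a, b, c) gives the same rotation as intrinsic ‘ZYX’ with angles (c, b, a), while ‘xyz’ and ‘XYZ’ with identical angles generally differ.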
- rotmat_to_euler(convention: str, rotmat, as_tuple=False, degrees=False, epsilon=1e-07)
Convert rotation matrix to Euler angles representation.
- Parameters:
convention (str) – string of 3 characters belonging to {‘x’, ‘y’, ‘z’} for extrinsic rotations, or {‘X’, ‘Y’, ‘Z’} for intrinsic rotations. Consecutive axes should not be identical.
rotmat (...x3x3 tensor) – input batch of rotation matrices.
as_tuple (boolean) – if True, angles are not stacked but returned as a tuple of tensors.
degrees (bool) – if True, angles are returned in degrees.
epsilon (float) – a small value used to detect degenerate configurations.
- Returns:
A stacked …x3 tensor corresponding to Euler angles, expressed by default in radians. In case of gimbal lock, the third angle is arbitrarily set to 0.
- rotvec_to_euler(convention: str, rotvec, as_tuple=False, degrees=False, epsilon=1e-07)
Convert rotation vector to Euler angles representation.
- Parameters:
convention (str) – string of 3 characters belonging to {‘x’, ‘y’, ‘z’} for extrinsic rotations, or {‘X’, ‘Y’, ‘Z’} for intrinsic rotations. Consecutive axes should not be identical.
rotvec (...x3 tensor) – input batch of rotation vectors.
as_tuple (boolean) – if True, angles are not stacked but returned as a tuple of tensors.
degrees (bool) – if True, angles are returned in degrees.
epsilon (float) – a small value used to detect degenerate configurations.
- Returns:
A stacked …x3 tensor corresponding to Euler angles, expressed by default in radians. In case of gimbal lock, the third angle is arbitrarily set to 0.
- unitquat_to_euler(convention: str, quat, as_tuple=False, degrees=False, epsilon=1e-07)
Convert unit quaternion to Euler angles representation.
- Parameters:
convention (str) – string of 3 characters belonging to {‘x’, ‘y’, ‘z’} for extrinsic rotations, or {‘X’, ‘Y’, ‘Z’} for intrinsic rotations. Consecutive axes should not be identical.
quat (...x4 tensor, XYZW convention) – input batch of unit quaternions.
as_tuple (boolean) – if True, angles are not stacked but returned as a tuple of tensors.
degrees (bool) – if True, angles are returned in degrees.
epsilon (float) – a small value used to detect degenerate configurations.
- Returns:
A stacked …x3 tensor corresponding to Euler angles, expressed by default in radians. In case of gimbal lock, the third angle is arbitrarily set to 0.
Utils
Various utility functions related to rotation representations.
- identity_quat(size=(), dtype=torch.float32, device=None)
Return a batch of identity unit quaternions.
- Parameters:
size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.
- Returns:
batch of identity quaternions (size x 4 tensor, XYZW convention).
Note
All returned batch quaternions refer to the same memory location. Consider cloning the output tensor prior to performing any in-place operations.
- is_orthonormal_matrix(R, epsilon=1e-07)
Test if matrices are orthonormal.
- Parameters:
R (...xDxD tensor) – batch of square matrices.
epsilon – tolerance threshold.
- Returns:
boolean tensor (shape …).
- is_rotation_matrix(R, epsilon=1e-07)
Test if matrices are rotation matrices.
- Parameters:
R (...xDxD tensor) – batch of square matrices.
epsilon – tolerance threshold.
- Returns:
boolean tensor (shape …).
- is_torch_batch_svd_available() bool
Returns True if the module ‘torch_batch_svd’ has been loaded. Returns False otherwise.
- quat_action(q, v, is_normalized=False)
Rotate a 3D vector \(v=(x,y,z)\) by a rotation represented by a quaternion q.
Based on the action by conjugation \((q, v) \mapsto q v q^{-1}\), considering the pure quaternion \(v = x i + y j + z k\) by abuse of notation.
- Parameters:
q (...x4 tensor, XYZW convention) – batch of quaternions.
v (...x3 tensor) – batch of 3D vectors.
is_normalized – use True if the input quaternions are already normalized, to avoid unnecessary computations.
- Returns:
batch of rotated 3D vectors (…x3 tensor).
Note
One should favor rotation matrix representation to rotate multiple vectors by the same rotation efficiently.
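For a unit quaternion, the conjugation q v q^{-1} can be expanded into two cross products, which avoids forming full quaternion products. The plain-Python sketch below illustrates this for a single quaternion and a single vector. It assumes the input quaternion is already normalized, the function name is hypothetical, and it is not RoMa's implementation.

```python
import math

def quat_action_sketch(q, v):
    """Rotate 3D vector v by the unit quaternion q = (x, y, z, w)."""
    x, y, z, w = q
    # t = 2 * (u x v), with u = (x, y, z) the vector part of q.
    t = (2.0 * (y * v[2] - z * v[1]),
         2.0 * (z * v[0] - x * v[2]),
         2.0 * (x * v[1] - y * v[0]))
    # Rotated vector: v + w * t + u x t.
    return (v[0] + w * t[0] + (y * t[2] - z * t[1]),
            v[1] + w * t[1] + (z * t[0] - x * t[2]),
            v[2] + w * t[2] + (x * t[1] - y * t[0]))

q_z90 = (0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4))  # 90 degrees around z
rotated = quat_action_sketch(q_z90, (1.0, 0.0, 0.0))
print(rotated)  # close to (0, 1, 0)
```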
- quat_composition(sequence, normalize=False)
Returns the product of a sequence of quaternions.
- Parameters:
sequence (sequence of ...x4 tensors, XYZW convention) – sequence of batches of quaternions.
normalize (bool) – if True, normalize the returned quaternion.
- Returns:
batch of quaternions (…x4 tensor, XYZW convention).
- quat_conjugation(quat)
Returns the conjugation of input batch of quaternions.
- Parameters:
quat (...x4 tensor, XYZW convention) – batch of quaternions.
- Returns:
batch of quaternions (…x4 tensor, XYZW convention).
Note
Conjugation of a unit quaternion is equal to its inverse.
- quat_inverse(quat)
Returns the inverse of a batch of quaternions.
- Parameters:
quat (...x4 tensor, XYZW convention) – batch of quaternions.
- Returns:
batch of quaternions (…x4 tensor, XYZW convention).
Note
Inverse of null quaternion is undefined.
For unit quaternions, consider using conjugation instead.
- quat_normalize(quat)
Returns a normalized, unit norm, copy of a batch of quaternions.
- Parameters:
quat (...x4 tensor, XYZW convention) – batch of quaternions.
- Returns:
batch of quaternions (…x4 tensor, XYZW convention).
- quat_product(p, q)
Returns the product of two quaternions.
- Parameters:
p (...x4 tensor, XYZW convention) – batch of quaternions.
q (...x4 tensor, XYZW convention) – batch of quaternions.
- Returns:
batch of quaternions (…x4 tensor, XYZW convention).
- random_rotmat(size=(), dtype=torch.float32, device=None)
Generates a batch of random 3x3 rotation matrices, uniformly sampled according to the usual rotation metric.
- Parameters:
size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.
- Returns:
batch of rotation matrices (size x 3x3 tensor).
- random_rotvec(size=(), dtype=torch.float32, device=None)
Generates a batch of random rotation vectors, uniformly sampled according to the usual rotation metric.
- Parameters:
size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.
- Returns:
batch of rotation vectors (size x 3 tensor).
- random_unitquat(size=(), dtype=torch.float32, device=None)
Generates a batch of random unit quaternions, uniformly sampled according to the usual quaternion metric.
- Parameters:
size (tuple or int) – batch size. Use for example tuple() to generate a single element, and (5,2) to generate a 5x2 batch.
- Returns:
batch of unit quaternions (size x 4 tensor).
- Reference:
Shoemake, “Uniform Random Rotations”, in Graphics Gems III (IBM Version), Elsevier, 1992, pp. 124–132. doi: 10.1016/B978-0-08-050755-2.50036-1.
- rigid_points_registration(x, y, weights=None, compute_scaling=False)
Returns the rigid transformation \((R,t)\) and the optional scaling \(s\) that best align an input list of points \((x_i)_{i=1...n}\) to a target list of points \((y_i)_{i=1...n}\), by minimizing the sum of squared distances \(\sum_i w_i \|s R x_i + t - y_i\|^2\), where \((w_i)_{i=1...n}\) denotes optional positive weights. This is sometimes referred to as the Kabsch/Umeyama algorithm.
- Parameters:
x (...xNxD tensor) – list of N points of dimension D.
y (...xNxD tensor) – list of corresponding target points.
weights (None or ...xN tensor) – optional list of weights associated to each point.
- Returns:
a triplet \((R, t, s)\) consisting of a rotation matrix \(R\) (…xDxD tensor), a translation vector \(t\) (…xD tensor), and a scaling \(s\) (… tensor) if compute_scaling=True. Returns \((R, t)\) otherwise.
References
Umeyama, “Least-squares estimation of transformation parameters between two point patterns,” IEEE Transactions on pattern analysis and machine intelligence, vol. 13, no. 4, Art. no. 4, 1991.
Kabsch, “A solution for the best rotation to relate two sets of vectors”. Acta Crystallographica, A32, 1976.
- rigid_vectors_registration(x, y, weights=None, compute_scaling=False)
Returns the rotation matrix \(R\) and the optional scaling \(s\) that best align an input list of vectors \((x_i)_{i=1...n}\) to a target list of vectors \((y_i)_{i=1...n}\) by minimizing the sum of squared distances \(\sum_i w_i \|s R x_i - y_i\|^2\), where \((w_i)_{i=1...n}\) denotes optional positive weights. See rigid_points_registration() for details.
- Parameters:
x (...xNxD tensor) – list of N vectors of dimension D.
y (...xNxD tensor) – list of corresponding target vectors.
weights (None or ...xN tensor) – optional list of weights associated to each vector.
- Returns:
A tuple \((R, s)\) consisting of the rotation matrix \(R\) (…xDxD tensor) and the scaling \(s\) (… tensor) if compute_scaling=True. Returns the rotation matrix \(R\) otherwise.
- rotmat_composition(sequence, normalize=False)
Returns the product of a sequence of rotation matrices.
- Parameters:
sequence (sequence of ...xNxN tensors) – sequence of batches of rotation matrices.
normalize – if True, apply special Procrustes orthonormalization to compensate for numerical errors.
- Returns:
batch of rotation matrices (…xNxN tensor).
- rotmat_cosine_angle(R)
Returns the cosine of the rotation angle of the input 3x3 rotation matrix R. Based on the equality \(\mathrm{Trace}(R) = 1 + 2 \cos(\alpha)\).
- Parameters:
R (...x3x3 tensor) – batch of 3x3 rotation matrices.
- Returns:
batch of cosine angles (… tensor).
- rotmat_geodesic_distance(R1, R2, clamping=1.0)
Returns the angular distance \(\alpha\) between a pair of rotation matrices. Based on the equality \(\|R_2 - R_1\|_F = 2 \sqrt{2} \sin(\alpha/2)\).
- Parameters:
R1 (...x3x3 tensor) – batch of 3x3 rotation matrices.
R2 (...x3x3 tensor) – batch of 3x3 rotation matrices.
clamping – clamping value applied to the input of torch.asin(). Use 1.0 to ensure valid angular distances. Use a value strictly smaller than 1.0 to ensure finite gradients.
- Returns:
batch of angles in radians (… tensor).
- rotmat_geodesic_distance_naive(R1, R2)
Returns the angular distance between a pair of rotation matrices. Based on rotmat_cosine_angle() and less precise than rotmat_geodesic_distance() for nearby rotations.
- Parameters:
R1 (...x3x3 tensor) – batch of 3x3 rotation matrices.
R2 (...x3x3 tensor) – batch of 3x3 rotation matrices.
- Returns:
batch of angles in radians (… tensor).
- rotmat_inverse(R)
Returns the inverse of a rotation matrix.
- Parameters:
R (...xNxN tensor) – batch of rotation matrices.
- Returns:
batch of inverted rotation matrices (…xNxN tensor).
Warning
The function returns a transposed view of the input, therefore one should be careful with in-place operations.
- rotmat_slerp(R0, R1, steps)
Spherical linear interpolation between two rotation matrices.
- Parameters:
R0 (Ax3x3 tensor) – batch of rotation matrices (A may contain multiple dimensions).
R1 (Ax3x3 tensor) – batch of rotation matrices (A may contain multiple dimensions).
steps (tensor of shape B) – interpolation steps, 0.0 corresponding to R0 and 1.0 to R1 (B may contain multiple dimensions).
- Returns:
batch of interpolated rotation matrices (BxAx3x3 tensor).
- rotvec_composition(sequence, normalize=False)
Returns a rotation vector corresponding to the composition of a sequence of rotations represented by rotation vectors. Composition is performed using an intermediary quaternion representation.
- Parameters:
sequence (sequence of ...x3 tensors) – sequence of batches of rotation vectors.
normalize (bool) – if True, normalize intermediary representation to compensate for numerical errors.
- rotvec_geodesic_distance(vec1, vec2)
Returns the angular distance between rotations represented by rotation vectors (uses a conversion to unit quaternions internally).
- Parameters:
vec1 (...x3 tensors) – batch of rotation vectors.
vec2 (...x3 tensors) – batch of rotation vectors.
- Returns:
batch of angles in radians (… tensor).
- rotvec_inverse(rotvec)
Returns the inverse of the input rotation expressed using rotation vector representation.
- Parameters:
rotvec (...x3 tensor) – batch of rotation vectors.
- Returns:
batch of rotation vectors (…x3 tensor).
- rotvec_slerp(rotvec0, rotvec1, steps)
Spherical linear interpolation between two rotation vector representations.
- Parameters:
rotvec0 (Ax3 tensor) – batch of rotation vectors (A may contain multiple dimensions).
rotvec1 (Ax3 tensor) – batch of rotation vectors (A may contain multiple dimensions).
steps (tensor of shape B) – interpolation steps, 0.0 corresponding to rotvec0 and 1.0 to rotvec1 (B may contain multiple dimensions).
- Returns:
batch of interpolated rotation vectors (BxAx3 tensor).
- unitquat_geodesic_distance(q1, q2)
Returns the angular distance \(\alpha\) between rotations represented by unit quaternions. Based on the equality \(\min \|q_2 \pm q_1\| = 2 |\sin(\alpha/4)|\).
- Parameters:
q1 (...x4 tensor, XYZW convention) – batch of unit quaternions.
q2 (...x4 tensor, XYZW convention) – batch of unit quaternions.
- Returns:
batch of angles in radians (… tensor).
- unitquat_slerp(q0, q1, steps, shortest_arc=True)
Spherical linear interpolation between two unit quaternions.
- Parameters:
q0 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).
q1 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).
steps (tensor of shape B) – interpolation steps, 0.0 corresponding to q0 and 1.0 to q1 (B may contain multiple dimensions).
shortest_arc (boolean) – if True, interpolation will be performed along the shortest arc on SO(3) from q0 to q1 or -q1.
- Returns:
batch of interpolated quaternions (BxAx4 tensor).
Note
When considering quaternions as rotation representations, one should keep in mind that spherical interpolation is not necessarily performed along the shortest arc, depending on the sign of torch.sum(q0*q1,dim=-1).
Behavior is undefined when using shortest_arc=False with antipodal quaternions.
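To make the role of the sign of torch.sum(q0*q1,dim=-1) concrete, here is a minimal standalone sketch of spherical linear interpolation with the BxAx4 output layout described above. It is not RoMa's implementation: the name `unitquat_slerp_sketch` is hypothetical, and the code assumes q0 and q1 are neither identical nor antipodal (the division by sin(theta) is left unguarded).

```python
import torch

def unitquat_slerp_sketch(q0, q1, steps, shortest_arc=True):
    dot = torch.sum(q0 * q1, dim=-1, keepdim=True)      # A x 1
    if shortest_arc:
        # Flip q1 where the dot product is negative to follow the shortest arc on SO(3).
        q1 = torch.where(dot < 0, -q1, q1)
        dot = dot.abs()
    theta = torch.acos(dot.clamp(-1.0, 1.0))            # angle between q0 and q1 on the 3-sphere
    # Append singleton axes to steps so that it broadcasts against the A x 4 inputs.
    t = steps.reshape(steps.shape + (1,) * q0.dim())
    w0 = torch.sin((1.0 - t) * theta) / torch.sin(theta)
    w1 = torch.sin(t * theta) / torch.sin(theta)
    return w0 * q0 + w1 * q1                            # B x A x 4

torch.manual_seed(0)
q0 = torch.nn.functional.normalize(torch.randn(3, 4, dtype=torch.float64), dim=-1)
q1 = torch.nn.functional.normalize(torch.randn(3, 4, dtype=torch.float64), dim=-1)
steps = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
interpolated = unitquat_slerp_sketch(q0, q1, steps)
```

With shortest_arc=True the last interpolated quaternion equals q1 or -q1, which represent the same rotation.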
- unitquat_slerp_fast(q0, q1, steps, shortest_arc=True)
Spherical linear interpolation between two unit quaternions. This function requires fewer computations than roma.utils.unitquat_slerp(), but is unsuitable for extrapolation (i.e. steps must be within [0,1]).
- Parameters:
q0 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).
q1 (Ax4 tensor) – batch of unit quaternions (A may contain multiple dimensions).
steps (tensor of shape B) – interpolation steps within 0.0 and 1.0, 0.0 corresponding to q0 and 1.0 to q1 (B may contain multiple dimensions).
shortest_arc (boolean) – if True, interpolation will be performed along the shortest arc on SO(3) from q0 to q1 or -q1.
- Returns:
batch of interpolated quaternions (BxAx4 tensor).
Spatial transformations
Spatial transformations parameterized by rotation matrices, unit quaternions and more.
Example of use
import torch, roma
# Rigid transformation parameterized by a rotation matrix and a translation vector
T0 = roma.Rigid(linear=roma.random_rotmat(), translation=torch.randn(3))
# Rigid transformations parameterized by a unit quaternion and a translation vector
T1 = roma.RigidUnitQuat(linear=roma.random_unitquat(), translation=torch.randn(3))
T2 = roma.RigidUnitQuat(linear=roma.random_unitquat(), translation=torch.randn(3))
# Inverting and composing transformations
T = (T1.inverse() @ T2)
# Normalization to ensure that T is actually a rigid transformation.
T = T.normalize()
# Direct access to the translation part
T.translation += 0.5
# Transformation of points:
points = torch.randn(100,3)
# Adjusting the shape of T for proper broadcasting.
transformed_points = T[None].apply(points)
# Transformation of vectors:
vectors = torch.randn(10,20,3)
# Adjusting the shape of T for proper broadcasting.
transformed_vectors = T[None,None].linear_apply(vectors)
# Casting the transformation into a homogeneous 4x4 matrix.
M = T.to_homogeneous()
Applying a transformation
When applying a transformation to a set of points of coordinates v, the batch shape of v should be broadcastable with the batch shape of the transformation.
For example, one can sample a unique random rigid 3D transformation and use it to transform 100 random 3D points as follows:
roma.Rigid(roma.random_rotmat(), torch.randn(3))[None].apply(torch.randn(100,3))
To apply a different transformation to each point, one could use instead:
roma.Rigid(roma.random_rotmat(100), torch.randn(100,3)).apply(torch.randn(100,3))
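The broadcasting rule can be illustrated with plain tensors. The sketch below mimics the two cases above without using RoMa; `rigid_apply_sketch` is a hypothetical helper, and identity rotations are used only to keep the expected result obvious.

```python
import torch

def rigid_apply_sketch(R, t, v):
    # Column-vector convention: each point x is mapped to R x + t.
    # matmul broadcasts the batch dimensions of R against those of v.
    return (R @ v[..., None]).squeeze(-1) + t

points = torch.randn(100, 3)

# One shared transformation for 100 points: add a leading batch axis of size 1.
R, t = torch.eye(3), torch.tensor([1.0, 2.0, 3.0])
shared = rigid_apply_sketch(R[None], t[None], points)        # 100 x 3

# One transformation per point: the batch shapes already match.
R_batch = torch.eye(3).repeat(100, 1, 1)
t_batch = torch.randn(100, 3)
per_point = rigid_apply_sketch(R_batch, t_batch, points)     # 100 x 3
```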
Aliasing issues
Warning
For efficiency reasons, transformation objects do not copy input data. Be careful if you intend to modify data in place, and use the clone() method when required.
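The aliasing pattern behind this warning can be reproduced with a toy container. `TinyRigid` below is a hypothetical stand-in that, like the transformation objects described here, keeps references to its input tensors; it is not a RoMa class.

```python
import torch

class TinyRigid:
    """Hypothetical container storing references to its inputs without copying them."""

    def __init__(self, linear, translation):
        self.linear = linear
        self.translation = translation

    def clone(self):
        # Deep copy of the underlying tensors.
        return TinyRigid(self.linear.clone(), self.translation.clone())

translation = torch.zeros(3)
T = TinyRigid(torch.eye(3), translation)
T_copy = T.clone()

# An in-place change of the original tensor is visible through T, but not through its clone.
translation += 1.0
aliased = T.translation.tolist()
independent = T_copy.translation.tolist()
```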
- class Affine(linear, translation)
An affine transformation represented by a linear and a translation part.
- Variables:
linear – (…xCxD tensor): batch of matrices specifying the linear part.
translation – (…xD tensor or None): batch of vectors specifying the translation part.
- apply(v)
Transforms a tensor of points coordinates. See Applying a transformation.
- Parameters:
v (...xD tensor) – tensor of point coordinates to transform.
- Returns:
The transformed point coordinates.
- as_tuple()
- Returns:
a tuple of tensors containing the linear and translation parts of the transformation respectively.
- clone()
- Returns:
A copy of the transformation (useful to avoid aliasing issues).
- compose(other)
Compose a transformation with the current one.
- Parameters:
other – another transformation of the same type.
- Returns:
The resulting transformation.
- classmethod from_homogeneous(matrix)
Instantiate a new transformation from an input homogeneous (D+1)x(C+1) matrix. The input matrix is assumed to be normalized and to satisfy the properties of the transformation. No checks are performed.
- Parameters:
matrix (...x(D+1)x(C+1) tensor) – tensor of transformations expressed in homogeneous coordinates, normalized with a last row equal to (0,…,0,1).
- Returns:
The corresponding transformation.
- inverse()
- Returns:
The inverse transformation, when applicable.
- linear_apply(v)
Transforms a tensor of vector coordinates.
- Parameters:
v (...xD tensor) – tensor of vector coordinates to transform.
- Returns:
The transformed vector coordinates.
See note in apply() regarding broadcasting.
- linear_compose(other)
Compose the linear part of two transformations.
- Parameters:
other – another transformation of the same type.
- Returns:
a tensor representing the composed transformation.
- linear_inverse()
- Returns:
The inverse of the linear transformation, when applicable.
- normalize()
- Returns:
Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).
- squeeze(dim)
Return a view of the transformation in which a batch dimension equal to 1 has been squeezed.
- Variables:
dim – positive integer: The dimension to squeeze.
- to_homogeneous(output=None)
- Parameters:
output (...x(D+1)x(C+1) tensor or None) – optional tensor in which to store the result.
- Returns:
A …x(D+1)x(C+1) tensor of homogeneous matrices representing the transformation, normalized with a last row equal to (0,…,0,1).
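The homogeneous representation can be sketched as follows for the square case (C = D). The helpers `to_homogeneous_sketch` and `from_homogeneous_sketch` are hypothetical and only illustrate the layout with a last row equal to (0,…,0,1); they are not RoMa's implementation. The example also shows that multiplying homogeneous matrices composes the underlying affine transformations.

```python
import torch

def to_homogeneous_sketch(linear, translation):
    """Pack a ...xDxD linear part and a ...xD translation into ...x(D+1)x(D+1) matrices."""
    D = translation.shape[-1]
    batch_shape = tuple(translation.shape[:-1])
    H = torch.zeros(batch_shape + (D + 1, D + 1), dtype=linear.dtype)
    H[..., :D, :D] = linear
    H[..., :D, D] = translation
    H[..., D, D] = 1.0
    return H

def from_homogeneous_sketch(H):
    # Views of the input: top-left block and last column (without the final 1).
    return H[..., :-1, :-1], H[..., :-1, -1]

torch.manual_seed(0)
L1, t1 = torch.randn(3, 3, dtype=torch.float64), torch.randn(3, dtype=torch.float64)
L2, t2 = torch.randn(3, 3, dtype=torch.float64), torch.randn(3, dtype=torch.float64)

# The matrix product applies (L2, t2) first, then (L1, t1).
H = to_homogeneous_sketch(L1, t1) @ to_homogeneous_sketch(L2, t2)
L12, t12 = from_homogeneous_sketch(H)
```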
- class Isometry(linear, translation)
An isometric transformation represented by an orthonormal and a translation part.
- Variables:
linear – (…xDxD tensor or None): batch of matrices specifying the linear part.
translation – (…xD tensor or None): batch of vectors specifying the translation part.
- classmethod Identity(dim, batch_shape=(), dtype=torch.float32, device=None)
Return a default identity transformation.
- Variables:
dim – (strictly positive integer): dimension of the space in which the transformation operates (e.g. dim=3 for 3D transformations).
batch_shape – (tuple): batch dimensions considered.
- apply(v)
Transforms a tensor of points coordinates. See Applying a transformation.
- Parameters:
v (...xD tensor) – tensor of point coordinates to transform.
- Returns:
The transformed point coordinates.
- as_tuple()
- Returns:
a tuple of tensors containing the linear and translation parts of the transformation respectively.
- clone()
- Returns:
A copy of the transformation (useful to avoid aliasing issues).
- compose(other)
Compose a transformation with the current one.
- Parameters:
other – another transformation of the same type.
- Returns:
The resulting transformation.
- classmethod from_homogeneous(matrix)
Instantiate a new transformation from an input homogeneous (D+1)x(C+1) matrix. The input matrix is assumed to be normalized and to satisfy the properties of the transformation. No checks are performed.
- Parameters:
matrix (...x(D+1)x(C+1) tensor) – tensor of transformations expressed in homogeneous coordinates, normalized with a last row equal to (0,…,0,1).
- Returns:
The corresponding transformation.
- inverse()
- Returns:
The inverse transformation, when applicable.
- linear_apply(v)
Transforms a tensor of vector coordinates.
- Parameters:
v (...xD tensor) – tensor of vector coordinates to transform.
- Returns:
The transformed vector coordinates.
See note in apply() regarding broadcasting.
- linear_compose(other)
Compose the linear part of two transformations.
- Parameters:
other – another transformation of the same type.
- Returns:
a tensor representing the composed transformation.
- linear_inverse()
- Returns:
The inverse of the linear transformation, when applicable.
- linear_normalize()
- Returns:
Linear transformation normalized to an orthonormal matrix (…xDxD tensor).
- normalize()
- Returns:
Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).
- squeeze(dim)
Return a view of the transformation in which a batch dimension equal to 1 has been squeezed.
- Variables:
dim – positive integer: The dimension to squeeze.
- to_homogeneous(output=None)
- Parameters:
output (...x(D+1)x(C+1) tensor or None) – optional tensor in which to store the result.
- Returns:
A …x(D+1)x(C+1) tensor of homogeneous matrices representing the transformation, normalized with a last row equal to (0,…,0,1).
- class Linear(linear)
A linear transformation parameterized by a matrix \(M \in \mathcal{M}_{D,C}(\mathbb{R})\), transforming a point \(x \in \mathbb{R}^C\) into \(M x\).
- Variables:
linear – (…xDxC tensor): batch of matrices specifying the transformations considered.
- apply(v)
Transforms a tensor of points coordinates. See Applying a transformation.
- Parameters:
v (...xD tensor) – tensor of point coordinates to transform.
- Returns:
The transformed point coordinates.
- clone()
- Returns:
A copy of the transformation (useful to avoid aliasing issues).
- compose(other)
Compose a transformation with the current one.
- Parameters:
other – another transformation of the same type.
- Returns:
The resulting transformation.
- inverse()
- Returns:
The inverse transformation, when applicable.
- linear_apply(v)
Transforms a tensor of vector coordinates.
- Parameters:
v (...xD tensor) – tensor of vector coordinates to transform.
- Returns:
The transformed vector coordinates.
See note in apply() regarding broadcasting.
- linear_compose(other)
Compose the linear part of two transformations.
- Parameters:
other – another transformation of the same type.
- Returns:
a tensor representing the composed transformation.
- linear_inverse()
- Returns:
The inverse of the linear transformation, when applicable.
- class Orthonormal(linear)
An orthogonal transformation represented by an orthonormal matrix \(M \in \mathcal{M}_{D,D}(\mathbb{R})\), transforming a point \(x \in \mathbb{R}^D\) into \(M x\).
- Variables:
linear – (…xDxD tensor): batch of matrices \(M\) specifying the transformations considered.
- apply(v)
Transforms a tensor of points coordinates. See Applying a transformation.
- Parameters:
v (...xD tensor) – tensor of point coordinates to transform.
- Returns:
The transformed point coordinates.
- clone()
- Returns:
A copy of the transformation (useful to avoid aliasing issues).
- compose(other)
Compose a transformation with the current one.
- Parameters:
other – an other transformation of same type.
- Returns:
The resulting transformation.
- inverse()
- Returns:
The inverse transformation, when applicable.
- linear_apply(v)
Transforms a tensor of vector coordinates.
- Parameters:
v (...xD tensor) – tensor of vector coordinates to transform.
- Returns:
The transformed vector coordinates.
See note in apply() regarding broadcasting.
- linear_compose(other)
Compose the linear part of two transformations.
- Parameters:
other – another transformation of the same type.
- Returns:
a tensor representing the composed transformation.
- linear_inverse()
- Returns:
The inverse of the linear transformation, when applicable.
- linear_normalize()
- Returns:
Linear transformation normalized to an orthonormal matrix (…xDxD tensor).
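One common way to normalize an arbitrary matrix to an orthonormal one is the orthogonal Procrustes projection, which replaces the singular values of the matrix by ones. The sketch below illustrates that idea in plain PyTorch; it is an assumption that this matches what linear_normalize() does internally, and the name `orthonormalize_sketch` is hypothetical.

```python
import torch

def orthonormalize_sketch(M):
    # Closest orthonormal matrix in the Frobenius sense: M = U S Vh  ->  U Vh.
    U, _, Vh = torch.linalg.svd(M)
    return U @ Vh

torch.manual_seed(0)
M = torch.randn(4, 3, 3, dtype=torch.float64)   # batch of arbitrary 3x3 matrices
Q = orthonormalize_sketch(M)

# Q^T Q should be the identity for every matrix of the batch.
gram = Q.transpose(-1, -2) @ Q
```

Note that the result is orthonormal but not necessarily a rotation: its determinant may be -1.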
- class Rigid(linear, translation)
A rigid transformation represented by a rotation and a translation part.
- Variables:
linear – (…xDxD tensor or None): batch of matrices specifying the linear part.
translation – (…xD tensor or None): batch of vectors specifying the translation part.
- classmethod Identity(dim, batch_shape=(), dtype=torch.float32, device=None)
Return a default identity transformation.
- Variables:
dim – (strictly positive integer): dimension of the space in which the transformation operates (e.g. dim=3 for 3D transformations).
batch_shape – (tuple): batch dimensions considered.
- apply(v)
Transforms a tensor of points coordinates. See Applying a transformation.
- Parameters:
v (...xD tensor) – tensor of point coordinates to transform.
- Returns:
The transformed point coordinates.
- as_tuple()
- Returns:
a tuple of tensors containing the linear and translation parts of the transformation respectively.
- clone()
- Returns:
A copy of the transformation (useful to avoid aliasing issues).
- compose(other)
Compose a transformation with the current one.
- Parameters:
other – another transformation of the same type.
- Returns:
The resulting transformation.
- classmethod from_homogeneous(matrix)
Instantiate a new transformation from an input homogeneous (D+1)x(C+1) matrix. The input matrix is assumed to be normalized and to satisfy the properties of the transformation. No checks are performed.
- Parameters:
matrix (...x(D+1)x(C+1) tensor) – tensor of transformations expressed in homogeneous coordinates, normalized with a last row equal to (0,…,0,1).
- Returns:
The corresponding transformation.
- inverse()
- Returns:
The inverse transformation, when applicable.
- linear_apply(v)
Transforms a tensor of vector coordinates.
- Parameters:
v (...xD tensor) – tensor of vector coordinates to transform.
- Returns:
The transformed vector coordinates.
See note in apply() regarding broadcasting.
- linear_compose(other)
Compose the linear part of two transformations.
- Parameters:
other – another transformation of the same type.
- Returns:
a tensor representing the composed transformation.
- linear_inverse()
- Returns:
The inverse of the linear transformation, when applicable.
- linear_normalize()
- Returns:
Linear transformation normalized to a rotation matrix (…xDxD tensor).
- normalize()
- Returns:
Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).
- squeeze(dim)
Return a view of the transformation in which a batch dimension equal to 1 has been squeezed.
- Variables:
dim – positive integer: The dimension to squeeze.
- to_homogeneous(output=None)
- Parameters:
output (...x(D+1)x(C+1) tensor or None) – optional tensor in which to store the result.
- Returns:
A …x(D+1)x(C+1) tensor of homogeneous matrices representing the transformation, normalized with a last row equal to (0,…,0,1).
- to_rigidunitquat()
Returns the corresponding RigidUnitQuat transformation.
Note
Original and resulting transformations share the same translation tensor. Be careful in case of in-place modifications.
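For reference, the algebra behind inverting and composing rigid transformations can be written with plain tensors. This is a minimal sketch, assuming the usual convention that a composition applies its right-hand operand first; `rigid_inverse_sketch` and `rigid_compose_sketch` are hypothetical helpers, not RoMa functions.

```python
import math
import torch

def rigid_inverse_sketch(R, t):
    # (R, t)^-1 = (R^T, -R^T t), since the inverse of a rotation is its transpose.
    R_inv = R.transpose(-1, -2)
    return R_inv, -(R_inv @ t[..., None]).squeeze(-1)

def rigid_compose_sketch(Ra, ta, Rb, tb):
    # Apply (Rb, tb) first, then (Ra, ta): x -> Ra (Rb x + tb) + ta.
    return Ra @ Rb, (Ra @ tb[..., None]).squeeze(-1) + ta

# Rotation of 90 degrees around the z axis, combined with a translation.
c, s = math.cos(math.pi / 2), math.sin(math.pi / 2)
R = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64)
t = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)

R_inv, t_inv = rigid_inverse_sketch(R, t)
# Composing the inverse with the original transformation should give the identity.
R_id, t_id = rigid_compose_sketch(R_inv, t_inv, R, t)
```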
- class RigidUnitQuat(linear, translation)
A rigid transformation represented by a unit quaternion and a translation part.
- Variables:
linear – (…x4 tensor): batch of unit quaternions defining the rotation.
translation – (…x3 tensor): batch of vectors specifying the translation part.
Note
Quaternions are assumed to be of unit norm, for all internal operations. Use the normalize() method if needed.
- apply(v)
Transforms a tensor of points coordinates. See Applying a transformation.
- Parameters:
v (...xD tensor) – tensor of point coordinates to transform.
- Returns:
The transformed point coordinates.
- as_tuple()
- Returns:
a tuple of tensors containing the linear and translation parts of the transformation respectively.
- clone()
- Returns:
A copy of the transformation (useful to avoid aliasing issues).
- compose(other)
Compose a transformation with the current one.
- Parameters:
other – another transformation of the same type.
- Returns:
The resulting transformation.
- static from_homogeneous(matrix)
Instantiate a new transformation from an input homogeneous (D+1)x(D+1) matrix.
- Parameters:
matrix (...x(D+1)x(D+1) tensor) – tensor of transformations expressed in homogeneous coordinates, normalized with a last row equal to (0,…,0,1).
- Returns:
The corresponding transformation.
Note
The input matrix is not tested to ensure that it satisfies the required properties of the transformation.
Components of the resulting transformation may consist in views of the input matrix. Be careful if you intend to modify it in-place.
- inverse()
- Returns:
The inverse transformation, when applicable.
- linear_apply(v)
Transforms a tensor of vector coordinates.
- Parameters:
v (...xD tensor) – tensor of vector coordinates to transform.
- Returns:
The transformed vector coordinates.
See note in apply() regarding broadcasting.
- linear_compose(other)
Compose the linear part of two transformations.
- Parameters:
other – another transformation of the same type.
- Returns:
a tensor representing the composed transformation.
- linear_inverse()
- Returns:
The inverse of the linear transformation, when applicable.
- linear_normalize()
- Returns:
Normalized unit quaternion (…x4 tensor).
- normalize()
- Returns:
Copy of the transformation, normalized to ensure the class properties (for example to ensure that a Rotation object is an actual rotation).
- squeeze(dim)
Return a view of the transformation in which a batch dimension equal to 1 has been squeezed.
- Variables:
dim – positive integer: The dimension to squeeze.
- to_homogeneous(output=None)
- Parameters:
output (...x4x4 tensor or None) – tensor in which to store the result (optional).
- Returns:
A …x4x4 tensor of homogeneous matrices representing the transformation, normalized with a last row equal to (0,…,0,1).
- to_rigid()
Returns the corresponding Rigid transformation.
Note
Original and resulting transformations share the same translation tensor. Be careful in case of in-place modifications.
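Converting between the quaternion-based and matrix-based rigid representations only changes the linear part. The sketch below shows the standard conversion from a unit quaternion in XYZW convention to a 3x3 rotation matrix; `unitquat_to_rotmat_sketch` is a hypothetical name and the code is an illustration rather than RoMa's implementation.

```python
import math
import torch

def unitquat_to_rotmat_sketch(q):
    """Map ...x4 unit quaternions (XYZW convention) to ...x3x3 rotation matrices."""
    x, y, z, w = q.unbind(dim=-1)
    return torch.stack([
        torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], dim=-1),
        torch.stack([2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], dim=-1),
        torch.stack([2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], dim=-1),
    ], dim=-2)

# Quarter turn around the z axis.
half = math.pi / 4
q = torch.tensor([0.0, 0.0, math.sin(half), math.cos(half)], dtype=torch.float64)
R = unitquat_to_rotmat_sketch(q)

# Column-vector convention: the x axis is mapped to the y axis.
rotated_x = R @ torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64)
```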
- class Rotation(linear)
A rotation represented by a rotation matrix \(R \in \mathcal{M}_{D,D}(\mathbb{R})\), transforming a point \(x \in \mathbb{R}^D\) into \(R x\).
- Variables:
linear – (…xDxD tensor): batch of matrices \(R\) defining the rotation.
- apply(v)
Transforms a tensor of points coordinates. See Applying a transformation.
- Parameters:
v (...xD tensor) – tensor of point coordinates to transform.
- Returns:
The transformed point coordinates.
- clone()
- Returns:
A copy of the transformation (useful to avoid aliasing issues).
- compose(other)
Compose a transformation with the current one.
- Parameters:
other – another transformation of the same type.
- Returns:
The resulting transformation.
- inverse()
- Returns:
The inverse transformation, when applicable.
- linear_apply(v)
Transforms a tensor of vector coordinates.
- Parameters:
v (...xD tensor) – tensor of vector coordinates to transform.
- Returns:
The transformed vector coordinates.
See note in apply() regarding broadcasting.
- linear_compose(other)
Compose the linear part of two transformations.
- Parameters:
other – another transformation of the same type.
- Returns:
a tensor representing the composed transformation.
- linear_inverse()
- Returns:
The inverse of the linear transformation, when applicable.
- linear_normalize()
- Returns:
Linear transformation normalized to a rotation matrix (…xDxD tensor).
- class RotationUnitQuat(linear)
A 3D rotation represented by a unit quaternion.
- Variables:
linear – (…x4 tensor, XYZW convention): batch of unit quaternions defining the rotation.
Note
Quaternions are assumed to be of unit norm, for all internal operations. Use normalize() if needed.
- apply(v)
Transforms a tensor of points coordinates. See Applying a transformation.
- Parameters:
v (...xD tensor) – tensor of point coordinates to transform.
- Returns:
The transformed point coordinates.
- clone()
- Returns:
A copy of the transformation (useful to avoid aliasing issues).
- compose(other)
Compose a transformation with the current one.
- Parameters:
other – another transformation of the same type.
- Returns:
The resulting transformation.
- inverse()
- Returns:
The inverse transformation, when applicable.
- linear_apply(v)
Transforms a tensor of vector coordinates.
- Parameters:
v (...xD tensor) – tensor of vector coordinates to transform.
- Returns:
The transformed vector coordinates.
See note in apply() regarding broadcasting.
- linear_compose(other)
Compose the linear part of two transformations.
- Parameters:
other – another transformation of the same type.
- Returns:
a tensor representing the composed transformation.
- linear_inverse()
- Returns:
The inverse of the linear transformation, when applicable.
- linear_normalize()
- Returns:
Normalized unit quaternion (…x4 tensor).
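A unit quaternion can rotate vectors directly, without building a rotation matrix. The sketch below uses the identity v' = v + 2w (u × v) + 2 u × (u × v), where u = (x, y, z) is the vector part and w the scalar part (XYZW convention). `unitquat_action_sketch` is a hypothetical helper written for illustration, not RoMa's implementation.

```python
import math
import torch

def unitquat_action_sketch(q, v):
    u, w = q[..., :3], q[..., 3:]             # vector part (x, y, z) and scalar part w
    u, v = torch.broadcast_tensors(u, v)      # align batch shapes before the cross products
    uv = torch.cross(u, v, dim=-1)
    return v + 2.0 * (w * uv + torch.cross(u, uv, dim=-1))

# Quarter turn around the z axis, applied to the three canonical axes.
half = math.pi / 4
q = torch.tensor([0.0, 0.0, math.sin(half), math.cos(half)], dtype=torch.float64)
vectors = torch.eye(3, dtype=torch.float64)   # rows are the x, y and z axes
rotated = unitquat_action_sketch(q, vectors)
```

The x axis is mapped to the y axis, the y axis to the negated x axis, and the z axis is left unchanged.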
Advanced
Running unit tests
From source repository:
python -m unittest
Building Sphinx documentation
From source repository:
./build_doc.sh
License
RoMa, Copyright (c) 2020 NAVER Corp., is licensed under the 3-Clause BSD License (see license).
Bits of code were adapted from SciPy. Documentation is generated, distributed and displayed with the support of Sphinx and other materials (see notice).
Changelog
- Version 1.5.1:
Syntactic sugar for Spatial transformations: support for default linear or translation parts, identity transformations and batch dimension squeezing.
- Version 1.5.0:
Added Euler angles mappings.
- Version 1.4.5:
3-Clause BSD Licensing.
- Version 1.4.4:
Added identity_quat().
- Version 1.4.3:
Fix normalization bug in quat_composition() (thanks jamiesalter for reporting).
- Version 1.4.2:
Fix for quat_action() to support arbitrary devices and types.
Added conversion functions between Rigid and RigidUnitQuat.
- Version 1.4.1:
Added XYZW / WXYZ quaternion conversion routines: quat_xyzw_to_wxyz() and quat_wxyz_to_xyzw().
Added rotvec_geodesic_distance().
- Version 1.4.0:
Added the Spatial transformations module.
- Version 1.3.4:
Use default torch.svd with pytorch versions greater than 1.8 (efficiency issue compared to torch_batch_svd solved in this PR: https://github.com/pytorch/pytorch/pull/48436).
- Version 1.3.3:
procrustes() can optionally return singular values, for advanced uses.
rigid_points_registration() and rigid_vectors_registration() can optionally return scaling estimations.
Added unitquat_geodesic_distance().
- Version 1.3.2:
Simplified backpropagation of procrustes().
Support for optional weights in rigid_points_registration() and rigid_vectors_registration().
Fix for random_unitquat() to support initialization on arbitrary device.
- Version 1.3.1:
Removed spurious code in procrustes().
Replaced warning about missing ‘torch_batch_svd’ module by a test function: is_torch_batch_svd_available().
Improved documentation and tests.
- Version 1.3.0:
Added roma.utils.quat_action().
Change of underlying algorithm for random_unitquat() to avoid potential divisions by 0.
Fix of roma.utils.unitquat_slerp() which was always performing interpolation along the shortest arc regardless of the value of the shortest_path argument (renamed shortest_arc in the new version).
- Version 1.2.7:
Fix of unflatten_batch_dims() to ensure compatibility with PyTorch 1.6.0.
Fix of symmatrixvec_to_unitquat() that was not producing a lower triangular matrix.
- Version 1.2.6:
Added an optional regularization argument to special_procrustes().
Added an optional clamping argument to rotmat_geodesic_distance().
Fix: rotvec_to_rotmat() no longer produces nonfinite gradient for null rotation vectors.
- Version 1.2.5:
Added an optional regularization argument for Procrustes orthonormalization.
Added a rigid registration example in the documentation.
- Version 1.2.4:
Procrustes: automatic fallback to vanilla SVD decomposition for large dimensions.
- Version 1.2.3:
Improved support for double precision tensors.
- Version 1.2.2:
Added rigid_points_registration() and rigid_vectors_registration().
Added rotmat_slerp().
Circumvented a deprecation warning with torch.symeig() when using recent PyTorch versions.
- Version 1.2.1:
Open-source release.