Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 504f545

Browse filesBrowse files
authored
Added constrained uniform sampling of angles for SE3, SO3, and Quaternions (bdaiinstitute#139)
1 parent cf115f6 commit 504f545
Copy full SHA for 504f545

File tree

Expand file treeCollapse file tree

5 files changed

+148
-19
lines changed
Filter options
Expand file treeCollapse file tree

5 files changed

+148
-19
lines changed

‎spatialmath/base/quaternions.py

Copy file name to clipboardExpand all lines: spatialmath/base/quaternions.py
+98-12Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
import math
1515
import numpy as np
1616
import spatialmath.base as smb
17+
from spatialmath.base.argcheck import getunit
1718
from spatialmath.base.types import *
19+
import scipy.interpolate as interpolate
20+
from typing import Optional
21+
from functools import lru_cache
1822

1923
_eps = np.finfo(np.float64).eps
2024

21-
2225
def qeye() -> QuaternionArray:
2326
"""
2427
Create an identity quaternion
@@ -843,29 +846,112 @@ def qslerp(
843846
return q0
844847

845848

846-
def qrand() -> UnitQuaternionArray:
849+
def _compute_cdf_sin_squared(theta: float):
847850
"""
848-
Random unit-quaternion
851+
Computes the CDF for the distribution of angular magnitude for uniformly sampled rotations.
852+
853+
:arg theta: angular magnitude
854+
:rtype: float
855+
:return: cdf of a given angular magnitude
856+
:rtype: float
857+
858+
Helper function for uniform sampling of rotations with constrained angular magnitude.
859+
This function returns the integral of the pdf of angular magnitudes (2/pi * sin^2(theta/2)).
860+
"""
861+
return (theta - np.sin(theta)) / np.pi
849862

863+
@lru_cache(maxsize=1)
864+
def _generate_inv_cdf_sin_squared_interp(num_interpolation_points: int = 256) -> interpolate.interp1d:
865+
"""
866+
Computes an interpolation function for the inverse CDF of the distribution of angular magnitude.
867+
868+
:arg num_interpolation_points: number of points to use in the interpolation function
869+
:rtype: int
870+
:return: interpolation function for the inverse cdf of a given angular magnitude
871+
:rtype: interpolate.interp1d
872+
873+
Helper function for uniform sampling of rotations with constrained angular magnitude.
874+
This function returns interpolation function for the inverse of the integral of the
875+
pdf of angular magnitudes (2/pi * sin^2(theta/2)), which is not analytically defined.
876+
"""
877+
cdf_sin_squared_interp_angles = np.linspace(0, np.pi, num_interpolation_points)
878+
cdf_sin_squared_interp_values = _compute_cdf_sin_squared(cdf_sin_squared_interp_angles)
879+
return interpolate.interp1d(cdf_sin_squared_interp_values, cdf_sin_squared_interp_angles)
880+
881+
def _compute_inv_cdf_sin_squared(x: ArrayLike, num_interpolation_points: int = 256) -> ArrayLike:
882+
"""
883+
Computes the inverse CDF of the distribution of angular magnitude.
884+
885+
:arg x: value for cdf of angular magnitudes
886+
:rtype: ArrayLike
887+
:arg num_interpolation_points: number of points to use in the interpolation function
888+
:rtype: int
889+
:return: angular magnitude associate with cdf value
890+
:rtype: ArrayLike
891+
892+
Helper function for uniform sampling of rotations with constrained angular magnitude.
893+
This function returns the angle associated with the cdf value derived form integral of
894+
the pdf of angular magnitudes (2/pi * sin^2(theta/2)), which is not analytically defined.
895+
"""
896+
inv_cdf_sin_squared_interp = _generate_inv_cdf_sin_squared_interp(num_interpolation_points)
897+
return inv_cdf_sin_squared_interp(x)
898+
899+
def qrand(theta_range:Optional[ArrayLike2] = None, unit: str = "rad", num_interpolation_points: int = 256) -> UnitQuaternionArray:
900+
"""
901+
Random unit-quaternion
902+
903+
:arg theta_range: angular magnitude range [min,max], defaults to None.
904+
:type xrange: 2-element sequence, optional
905+
:arg unit: angular units: 'rad' [default], or 'deg'
906+
:type unit: str
907+
:arg num_interpolation_points: number of points to use in the interpolation function
908+
:rtype: int
909+
:arg num_interpolation_points: number of points to use in the interpolation function
910+
:rtype: int
850911
:return: random unit-quaternion
851912
:rtype: ndarray(4)
852913
853-
Computes a uniformly distributed random unit-quaternion which can be
854-
considered equivalent to a random SO(3) rotation.
914+
Computes a uniformly distributed random unit-quaternion, with in a maximum
915+
angular magnitude, which can be considered equivalent to a random SO(3) rotation.
855916
856917
.. runblock:: pycon
857918
858919
>>> from spatialmath.base import qrand, qprint
859920
>>> qprint(qrand())
860921
"""
861-
u = np.random.uniform(low=0, high=1, size=3) # get 3 random numbers in [0,1]
862-
return np.r_[
863-
math.sqrt(1 - u[0]) * math.sin(2 * math.pi * u[1]),
864-
math.sqrt(1 - u[0]) * math.cos(2 * math.pi * u[1]),
865-
math.sqrt(u[0]) * math.sin(2 * math.pi * u[2]),
866-
math.sqrt(u[0]) * math.cos(2 * math.pi * u[2]),
867-
]
922+
if theta_range is not None:
923+
theta_range = getunit(theta_range, unit)
924+
925+
if(theta_range[0] < 0 or theta_range[1] > np.pi or theta_range[0] > theta_range[1]):
926+
ValueError('Invalid angular range. Must be within the range[0, pi].'
927+
+ f' Recieved {theta_range}.')
928+
929+
# Sample axis and angle independently, respecting the CDF of the
930+
# angular magnitude under uniform sampling.
931+
932+
# Sample angle using inverse transform sampling based on CDF
933+
# of the angular distribution (2/pi * sin^2(theta/2))
934+
theta = _compute_inv_cdf_sin_squared(
935+
np.random.uniform(
936+
low=_compute_cdf_sin_squared(theta_range[0]),
937+
high=_compute_cdf_sin_squared(theta_range[1]),
938+
),
939+
num_interpolation_points=num_interpolation_points,
940+
)
941+
# Sample axis uniformly using 3D normal distributed
942+
v = np.random.randn(3)
943+
v /= np.linalg.norm(v)
868944

945+
return np.r_[math.cos(theta / 2), (math.sin(theta / 2) * v)]
946+
else:
947+
u = np.random.uniform(low=0, high=1, size=3) # get 3 random numbers in [0,1]
948+
return np.r_[
949+
math.sqrt(1 - u[0]) * math.sin(2 * math.pi * u[1]),
950+
math.sqrt(1 - u[0]) * math.cos(2 * math.pi * u[1]),
951+
math.sqrt(u[0]) * math.sin(2 * math.pi * u[2]),
952+
math.sqrt(u[0]) * math.cos(2 * math.pi * u[2]),
953+
]
954+
869955

870956
def qmatrix(q: ArrayLike4) -> R4x4:
871957
"""

‎spatialmath/pose3d.py

Copy file name to clipboardExpand all lines: spatialmath/pose3d.py
+14-4Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from spatialmath.twist import Twist3
3636

37-
from typing import TYPE_CHECKING
37+
from typing import TYPE_CHECKING, Optional
3838
if TYPE_CHECKING:
3939
from spatialmath.quaternion import UnitQuaternion
4040

@@ -455,12 +455,16 @@ def Rz(cls, theta, unit: str = "rad") -> Self:
455455
return cls([smb.rotz(x, unit=unit) for x in smb.getvector(theta)], check=False)
456456

457457
@classmethod
458-
def Rand(cls, N: int = 1) -> Self:
458+
def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> Self:
459459
"""
460460
Construct a new SO(3) from random rotation
461461
462462
:param N: number of random rotations
463463
:type N: int
464+
:param theta_range: angular magnitude range [min,max], defaults to None.
465+
:type xrange: 2-element sequence, optional
466+
:param unit: angular units: 'rad' [default], or 'deg'
467+
:type unit: str
464468
:return: SO(3) rotation matrix
465469
:rtype: SO3 instance
466470
@@ -477,7 +481,7 @@ def Rand(cls, N: int = 1) -> Self:
477481
478482
:seealso: :func:`spatialmath.quaternion.UnitQuaternion.Rand`
479483
"""
480-
return cls([smb.q2r(smb.qrand()) for _ in range(0, N)], check=False)
484+
return cls([smb.q2r(smb.qrand(theta_range=theta_range, unit=unit)) for _ in range(0, N)], check=False)
481485

482486
@overload
483487
@classmethod
@@ -1517,6 +1521,8 @@ def Rand(
15171521
xrange: Optional[ArrayLike2] = (-1, 1),
15181522
yrange: Optional[ArrayLike2] = (-1, 1),
15191523
zrange: Optional[ArrayLike2] = (-1, 1),
1524+
theta_range:Optional[ArrayLike2] = None,
1525+
unit: str = "rad",
15201526
) -> SE3: # pylint: disable=arguments-differ
15211527
"""
15221528
Create a random SE(3)
@@ -1527,6 +1533,10 @@ def Rand(
15271533
:type yrange: 2-element sequence, optional
15281534
:param zrange: z-axis range [min,max], defaults to [-1, 1]
15291535
:type zrange: 2-element sequence, optional
1536+
:param theta_range: angular magnitude range [min,max], defaults to None -> [0,pi].
1537+
:type xrange: 2-element sequence, optional
1538+
:param unit: angular units: 'rad' [default], or 'deg'
1539+
:type unit: str
15301540
:param N: number of random transforms
15311541
:type N: int
15321542
:return: SE(3) matrix
@@ -1557,7 +1567,7 @@ def Rand(
15571567
Z = np.random.uniform(
15581568
low=zrange[0], high=zrange[1], size=N
15591569
) # random values in the range
1560-
R = SO3.Rand(N=N)
1570+
R = SO3.Rand(N=N, theta_range=theta_range, unit=unit)
15611571
return cls(
15621572
[smb.transl(x, y, z) @ smb.r2t(r.A) for (x, y, z, r) in zip(X, Y, Z, R)],
15631573
check=False,

‎spatialmath/quaternion.py

Copy file name to clipboardExpand all lines: spatialmath/quaternion.py
+6-2Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,12 +1225,16 @@ def Rz(cls, angles: ArrayLike, unit: Optional[str] = "rad") -> UnitQuaternion:
12251225
)
12261226

12271227
@classmethod
1228-
def Rand(cls, N: int = 1) -> UnitQuaternion:
1228+
def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> UnitQuaternion:
12291229
"""
12301230
Construct a new random unit quaternion
12311231
12321232
:param N: number of random rotations
12331233
:type N: int
1234+
:param theta_range: angular magnitude range [min,max], defaults to None -> [0,pi].
1235+
:type xrange: 2-element sequence, optional
1236+
:param unit: angular units: 'rad' [default], or 'deg'
1237+
:type unit: str
12341238
:return: random unit-quaternion
12351239
:rtype: UnitQuaternion instance
12361240
@@ -1248,7 +1252,7 @@ def Rand(cls, N: int = 1) -> UnitQuaternion:
12481252
12491253
:seealso: :meth:`UnitQuaternion.Rand`
12501254
"""
1251-
return cls([smb.qrand() for i in range(0, N)], check=False)
1255+
return cls([smb.qrand(theta_range=theta_range, unit=unit) for i in range(0, N)], check=False)
12521256

12531257
@classmethod
12541258
def Eul(cls, *angles: List[float], unit: Optional[str] = "rad") -> UnitQuaternion:

‎tests/test_pose3d.py

Copy file name to clipboardExpand all lines: tests/test_pose3d.py
+17-1Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,19 @@ def test_constructor(self):
7272
array_compare(R, np.eye(3))
7373
self.assertIsInstance(R, SO3)
7474

75+
np.random.seed(32)
7576
# random
7677
R = SO3.Rand()
7778
nt.assert_equal(len(R), 1)
7879
self.assertIsInstance(R, SO3)
7980

81+
# random constrained
82+
R = SO3.Rand(theta_range=(0.1, 0.7))
83+
self.assertIsInstance(R, SO3)
84+
self.assertEqual(R.A.shape, (3, 3))
85+
self.assertLessEqual(R.angvec()[0], 0.7)
86+
self.assertGreaterEqual(R.angvec()[0], 0.1)
87+
8088
# copy constructor
8189
R = SO3.Rx(pi / 2)
8290
R2 = SO3(R)
@@ -816,12 +824,13 @@ def test_constructor(self):
816824
array_compare(R, np.eye(4))
817825
self.assertIsInstance(R, SE3)
818826

827+
np.random.seed(65)
819828
# random
820829
R = SE3.Rand()
821830
nt.assert_equal(len(R), 1)
822831
self.assertIsInstance(R, SE3)
823832

824-
# random
833+
# random
825834
T = SE3.Rand()
826835
R = T.R
827836
t = T.t
@@ -847,6 +856,13 @@ def test_constructor(self):
847856
nt.assert_equal(TT.y, ones * t[1])
848857
nt.assert_equal(TT.z, ones * t[2])
849858

859+
# random constrained
860+
T = SE3.Rand(theta_range=(0.1, 0.7))
861+
self.assertIsInstance(T, SE3)
862+
self.assertEqual(T.A.shape, (4, 4))
863+
self.assertLessEqual(T.angvec()[0], 0.7)
864+
self.assertGreaterEqual(T.angvec()[0], 0.1)
865+
850866
# copy constructor
851867
R = SE3.Rx(pi / 2)
852868
R2 = SE3(R)

‎tests/test_quaternion.py

Copy file name to clipboardExpand all lines: tests/test_quaternion.py
+13Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,19 @@ def test_constructor_variants(self):
4848
nt.assert_array_almost_equal(
4949
UnitQuaternion.Rz(-90, "deg").vec, np.r_[1, 0, 0, -1] / math.sqrt(2)
5050
)
51+
52+
np.random.seed(73)
53+
q = UnitQuaternion.Rand(theta_range=(0.1, 0.7))
54+
self.assertIsInstance(q, UnitQuaternion)
55+
self.assertLessEqual(q.angvec()[0], 0.7)
56+
self.assertGreaterEqual(q.angvec()[0], 0.1)
57+
58+
59+
q = UnitQuaternion.Rand(theta_range=(0.1, 0.7))
60+
self.assertIsInstance(q, UnitQuaternion)
61+
self.assertLessEqual(q.angvec()[0], 0.7)
62+
self.assertGreaterEqual(q.angvec()[0], 0.1)
63+
5164

5265
def test_constructor(self):
5366
qcompare(UnitQuaternion(), [1, 0, 0, 0])

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.