Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 59873f1

Browse filesBrowse files
authored
compute mean of a set of rotations (#160)
1 parent 2306628 commit 59873f1
Copy full SHA for 59873f1

File tree

Expand file treeCollapse file tree

4 files changed

+223
-29
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+223
-29
lines changed

‎spatialmath/pose3d.py

Copy file name to clipboardExpand all lines: spatialmath/pose3d.py
+33-10Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -843,22 +843,22 @@ def Exp(
843843

844844
def UnitQuaternion(self) -> UnitQuaternion:
845845
"""
846-
SO3 as a unit quaternion instance
846+
SO3 as a unit quaternion instance
847847
848-
:return: a unit quaternion representation
849-
:rtype: UnitQuaternion instance
848+
:return: a unit quaternion representation
849+
:rtype: UnitQuaternion instance
850850
851-
``R.UnitQuaternion()`` is an ``UnitQuaternion`` instance representing the same rotation
852-
as the SO3 rotation ``R``.
851+
``R.UnitQuaternion()`` is an ``UnitQuaternion`` instance representing the same rotation
852+
as the SO3 rotation ``R``.
853853
854-
Example:
854+
Example:
855855
856-
.. runblock:: pycon
856+
.. runblock:: pycon
857857
858-
>>> from spatialmath import SO3
859-
>>> SO3.Rz(0.3).UnitQuaternion()
858+
>>> from spatialmath import SO3
859+
>>> SO3.Rz(0.3).UnitQuaternion()
860860
861-
"""
861+
"""
862862
# Function level import to avoid circular dependencies
863863
from spatialmath import UnitQuaternion
864864

@@ -931,6 +931,29 @@ def angdist(self, other: SO3, metric: int = 6) -> Union[float, ndarray]:
931931
else:
932932
return ad
933933

934+
def mean(self, tol: float = 20) -> SO3:
935+
"""Mean of a set of rotations
936+
937+
:param tol: iteration tolerance in units of eps, defaults to 20
938+
:type tol: float, optional
939+
:return: the mean rotation
940+
:rtype: :class:`SO3` instance.
941+
942+
Computes the Karcher mean of the set of rotations within the SO(3) instance.
943+
944+
:references:
945+
- `**Hartley, Trumpf** - "Rotation Averaging" - IJCV 2011 <https://users.cecs.anu.edu.au/~hartley/Papers/PDF/Hartley-Trumpf:Rotation-averaging:IJCV.pdf>`_, Algorithm 1, page 15.
946+
- `Karcher mean <https://en.wikipedia.org/wiki/Karcher_mean>`_
947+
"""
948+
949+
eta = tol * np.finfo(float).eps
950+
R_mean = self[0] # initial guess
951+
while True:
952+
r = np.dstack((R_mean.inv() * self).log()).mean(axis=2)
953+
if np.linalg.norm(r) < eta:
954+
return R_mean
955+
R_mean = R_mean @ self.Exp(r) # update estimate and normalize
956+
934957

935958
# ============================== SE3 =====================================#
936959

‎spatialmath/quaternion.py

Copy file name to clipboardExpand all lines: spatialmath/quaternion.py
+64-14Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True):
4545
r"""
4646
Construct a new quaternion
4747
48-
:param s: scalar
49-
:type s: float
50-
:param v: vector
51-
:type v: 3-element array_like
48+
:param s: scalar part
49+
:type s: float or ndarray(N)
50+
:param v: vector part
51+
:type v: ndarray(3), ndarray(Nx3)
5252
5353
- ``Quaternion()`` constructs a zero quaternion
5454
- ``Quaternion(s, v)`` construct a new quaternion from the scalar ``s``
@@ -78,7 +78,7 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True):
7878
super().__init__()
7979

8080
if s is None and smb.isvector(v, 4):
81-
v,s = (s,v)
81+
v, s = (s, v)
8282

8383
if v is None:
8484
# single argument
@@ -92,6 +92,11 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True):
9292
# Quaternion(s, v)
9393
self.data = [np.r_[s, smb.getvector(v)]]
9494

95+
elif (
96+
smb.isvector(s) and smb.ismatrix(v, (None, 3)) and s.shape[0] == v.shape[0]
97+
):
98+
# Quaternion(s, v) where s and v are arrays
99+
self.data = [np.r_[_s, _v] for _s, _v in zip(s, v)]
95100
else:
96101
raise ValueError("bad argument to Quaternion constructor")
97102

@@ -395,9 +400,23 @@ def log(self) -> Quaternion:
395400
:seealso: :meth:`Quaternion.exp` :meth:`Quaternion.log` :meth:`UnitQuaternion.angvec`
396401
"""
397402
norm = self.norm()
398-
s = math.log(norm)
399-
v = math.acos(np.clip(self.s / norm, -1, 1)) * smb.unitvec(self.v)
400-
return Quaternion(s=s, v=v)
403+
s = np.log(norm)
404+
if len(self) == 1:
405+
if smb.iszerovec(self._A[1:4]):
406+
v = np.zeros((3,))
407+
else:
408+
v = math.acos(np.clip(self._A[0] / norm, -1, 1)) * smb.unitvec(
409+
self._A[1:4]
410+
)
411+
return Quaternion(s=s, v=v)
412+
else:
413+
v = [
414+
np.zeros((3,))
415+
if smb.iszerovec(A[1:4])
416+
else math.acos(np.clip(A[0] / n, -1, 1)) * smb.unitvec(A[1:4])
417+
for A, n in zip(self._A, norm)
418+
]
419+
return Quaternion(s=s, v=np.array(v))
401420

402421
def exp(self, tol: float = 20) -> Quaternion:
403422
r"""
@@ -437,7 +456,11 @@ def exp(self, tol: float = 20) -> Quaternion:
437456
exp_s = math.exp(self.s)
438457
norm_v = smb.norm(self.v)
439458
s = exp_s * math.cos(norm_v)
440-
v = exp_s * self.v / norm_v * math.sin(norm_v)
459+
if smb.iszerovec(self.v, tol * _eps):
460+
# result will be a unit quaternion
461+
v = np.zeros((3,))
462+
else:
463+
v = exp_s * self.v / norm_v * math.sin(norm_v)
441464
if abs(self.s) < tol * _eps:
442465
# result will be a unit quaternion
443466
return UnitQuaternion(s=s, v=v)
@@ -1260,7 +1283,7 @@ def Eul(cls, *angles: List[float], unit: Optional[str] = "rad") -> UnitQuaternio
12601283
Construct a new unit quaternion from Euler angles
12611284
12621285
:param 𝚪: 3-vector of Euler angles
1263-
:type 𝚪: array_like
1286+
:type 𝚪: 3 floats, array_like(3) or ndarray(N,3)
12641287
:param unit: angular units: 'rad' [default], or 'deg'
12651288
:type unit: str
12661289
:return: unit-quaternion
@@ -1286,20 +1309,23 @@ def Eul(cls, *angles: List[float], unit: Optional[str] = "rad") -> UnitQuaternio
12861309
if len(angles) == 1:
12871310
angles = angles[0]
12881311

1289-
return cls(smb.r2q(smb.eul2r(angles, unit=unit)), check=False)
1312+
if smb.isvector(angles, 3):
1313+
return cls(smb.r2q(smb.eul2r(angles, unit=unit)), check=False)
1314+
else:
1315+
return cls([smb.r2q(smb.eul2r(a, unit=unit)) for a in angles], check=False)
12901316

12911317
@classmethod
12921318
def RPY(
12931319
cls,
1294-
*angles: List[float],
1320+
*angles,
12951321
order: Optional[str] = "zyx",
12961322
unit: Optional[str] = "rad",
12971323
) -> UnitQuaternion:
12981324
r"""
12991325
Construct a new unit quaternion from roll-pitch-yaw angles
13001326
13011327
:param 𝚪: 3-vector of roll-pitch-yaw angles
1302-
:type 𝚪: array_like
1328+
:type 𝚪: 3 floats, array_like(3) or ndarray(N,3)
13031329
:param unit: angular units: 'rad' [default], or 'deg'
13041330
:type unit: str
13051331
:param unit: rotation order: 'zyx' [default], 'xyz', or 'yxz'
@@ -1341,7 +1367,13 @@ def RPY(
13411367
if len(angles) == 1:
13421368
angles = angles[0]
13431369

1344-
return cls(smb.r2q(smb.rpy2r(angles, unit=unit, order=order)), check=False)
1370+
if smb.isvector(angles, 3):
1371+
return cls(smb.r2q(smb.rpy2r(angles, unit=unit, order=order)), check=False)
1372+
else:
1373+
return cls(
1374+
[smb.r2q(smb.rpy2r(a, unit=unit, order=order)) for a in angles],
1375+
check=False,
1376+
)
13451377

13461378
@classmethod
13471379
def OA(cls, o: ArrayLike3, a: ArrayLike3) -> UnitQuaternion:
@@ -1569,6 +1601,24 @@ def dotb(self, omega: ArrayLike3) -> R4:
15691601
"""
15701602
return smb.qdotb(self._A, omega)
15711603

1604+
# def mean(self, tol: float = 20) -> SO3:
1605+
# """Mean of a set of rotations
1606+
1607+
# :param tol: iteration tolerance in units of eps, defaults to 20
1608+
# :type tol: float, optional
1609+
# :return: the mean rotation
1610+
# :rtype: :class:`UnitQuaternion` instance.
1611+
1612+
# Computes the Karcher mean of the set of rotations within the unit quaternion instance.
1613+
1614+
# :references:
1615+
# - `**Hartley, Trumpf** - "Rotation Averaging" - IJCV 2011 <https://users.cecs.anu.edu.au/~hartley/Papers/PDF/Hartley-Trumpf:Rotation-averaging:IJCV.pdf>`_
1616+
# - `Karcher mean <https://en.wikipedia.org/wiki/Karcher_mean`_
1617+
# """
1618+
1619+
# R_mean = self.SO3().mean(tol=tol)
1620+
# return R_mean.UnitQuaternion()
1621+
15721622
def __mul__(
15731623
left, right: UnitQuaternion
15741624
) -> UnitQuaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument

‎tests/test_pose3d.py

Copy file name to clipboardExpand all lines: tests/test_pose3d.py
+31Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,37 @@ def test_functions_lie(self):
717717
nt.assert_equal(R, SO3.EulerVec(R.eulervec()))
718718
np.testing.assert_equal((R.inv() * R).eulervec(), np.zeros(3))
719719

720+
R = SO3() # identity matrix case
721+
722+
# Check log and exponential map
723+
nt.assert_equal(R, SO3.Exp(R.log()))
724+
np.testing.assert_equal((R.inv() * R).log(), np.zeros([3, 3]))
725+
726+
# Check euler vector map
727+
nt.assert_equal(R, SO3.EulerVec(R.eulervec()))
728+
np.testing.assert_equal((R.inv() * R).eulervec(), np.zeros(3))
729+
730+
def test_mean(self):
731+
rpy = np.ones((100, 1)) @ np.c_[0.1, 0.2, 0.3]
732+
R = SO3.RPY(rpy)
733+
self.assertEqual(len(R), 100)
734+
m = R.mean()
735+
self.assertIsInstance(m, SO3)
736+
array_compare(m, R[0])
737+
738+
# range of angles, mean should be the middle one, index=25
739+
R = SO3.Rz(np.linspace(start=0.3, stop=0.7, num=51))
740+
m = R.mean()
741+
self.assertIsInstance(m, SO3)
742+
array_compare(m, R[25])
743+
744+
# now add noise
745+
rng = np.random.default_rng(0) # reproducible random numbers
746+
rpy += rng.normal(scale=0.00001, size=(100, 3))
747+
R = SO3.RPY(rpy)
748+
m = R.mean()
749+
array_compare(m, SO3.RPY(0.1, 0.2, 0.3))
750+
720751

721752
# ============================== SE3 =====================================#
722753

‎tests/test_quaternion.py

Copy file name to clipboardExpand all lines: tests/test_quaternion.py
+95-5Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,25 +257,79 @@ def test_staticconstructors(self):
257257
UnitQuaternion.Rz(theta, "deg").R, rotz(theta, "deg")
258258
)
259259

260+
def test_constructor_RPY(self):
260261
# 3 angle
262+
q = UnitQuaternion.RPY([0.1, 0.2, 0.3])
263+
self.assertIsInstance(q, UnitQuaternion)
264+
self.assertEqual(len(q), 1)
265+
nt.assert_array_almost_equal(q.R, rpy2r(0.1, 0.2, 0.3))
266+
q = UnitQuaternion.RPY(0.1, 0.2, 0.3)
267+
self.assertIsInstance(q, UnitQuaternion)
268+
self.assertEqual(len(q), 1)
269+
nt.assert_array_almost_equal(q.R, rpy2r(0.1, 0.2, 0.3))
270+
q = UnitQuaternion.RPY(np.r_[0.1, 0.2, 0.3])
271+
self.assertIsInstance(q, UnitQuaternion)
272+
self.assertEqual(len(q), 1)
273+
nt.assert_array_almost_equal(q.R, rpy2r(0.1, 0.2, 0.3))
274+
261275
nt.assert_array_almost_equal(
262-
UnitQuaternion.RPY([0.1, 0.2, 0.3]).R, rpy2r(0.1, 0.2, 0.3)
276+
UnitQuaternion.RPY([10, 20, 30], unit="deg").R,
277+
rpy2r(10, 20, 30, unit="deg"),
263278
)
264-
265279
nt.assert_array_almost_equal(
266-
UnitQuaternion.Eul([0.1, 0.2, 0.3]).R, eul2r(0.1, 0.2, 0.3)
280+
UnitQuaternion.RPY([0.1, 0.2, 0.3], order="xyz").R,
281+
rpy2r(0.1, 0.2, 0.3, order="xyz"),
282+
)
283+
284+
angles = np.array(
285+
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]]
267286
)
287+
q = UnitQuaternion.RPY(angles)
288+
self.assertIsInstance(q, UnitQuaternion)
289+
self.assertEqual(len(q), 4)
290+
for i in range(4):
291+
nt.assert_array_almost_equal(q[i].R, rpy2r(angles[i, :]))
292+
293+
q = UnitQuaternion.RPY(angles, order="xyz")
294+
self.assertIsInstance(q, UnitQuaternion)
295+
self.assertEqual(len(q), 4)
296+
for i in range(4):
297+
nt.assert_array_almost_equal(q[i].R, rpy2r(angles[i, :], order="xyz"))
268298

299+
angles *= 10
300+
q = UnitQuaternion.RPY(angles, unit="deg")
301+
self.assertIsInstance(q, UnitQuaternion)
302+
self.assertEqual(len(q), 4)
303+
for i in range(4):
304+
nt.assert_array_almost_equal(q[i].R, rpy2r(angles[i, :], unit="deg"))
305+
306+
def test_constructor_Eul(self):
269307
nt.assert_array_almost_equal(
270-
UnitQuaternion.RPY([10, 20, 30], unit="deg").R,
271-
rpy2r(10, 20, 30, unit="deg"),
308+
UnitQuaternion.Eul([0.1, 0.2, 0.3]).R, eul2r(0.1, 0.2, 0.3)
272309
)
273310

274311
nt.assert_array_almost_equal(
275312
UnitQuaternion.Eul([10, 20, 30], unit="deg").R,
276313
eul2r(10, 20, 30, unit="deg"),
277314
)
278315

316+
angles = np.array(
317+
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]]
318+
)
319+
q = UnitQuaternion.Eul(angles)
320+
self.assertIsInstance(q, UnitQuaternion)
321+
self.assertEqual(len(q), 4)
322+
for i in range(4):
323+
nt.assert_array_almost_equal(q[i].R, eul2r(angles[i, :]))
324+
325+
angles *= 10
326+
q = UnitQuaternion.Eul(angles, unit="deg")
327+
self.assertIsInstance(q, UnitQuaternion)
328+
self.assertEqual(len(q), 4)
329+
for i in range(4):
330+
nt.assert_array_almost_equal(q[i].R, eul2r(angles[i, :], unit="deg"))
331+
332+
def test_constructor_AngVec(self):
279333
# (theta, v)
280334
th = 0.2
281335
v = unitvec([1, 2, 3])
@@ -286,6 +340,7 @@ def test_staticconstructors(self):
286340
)
287341
nt.assert_array_almost_equal(UnitQuaternion.AngVec(th, -v).R, angvec2r(th, -v))
288342

343+
def test_constructor_EulerVec(self):
289344
# (theta, v)
290345
th = 0.2
291346
v = unitvec([1, 2, 3])
@@ -830,6 +885,20 @@ def test_log(self):
830885
nt.assert_array_almost_equal(q1.log().exp(), q1)
831886
nt.assert_array_almost_equal(q2.log().exp(), q2)
832887

888+
q = Quaternion([q1, q2, q1, q2])
889+
qlog = q.log()
890+
nt.assert_array_almost_equal(qlog[0].exp(), q1)
891+
nt.assert_array_almost_equal(qlog[1].exp(), q2)
892+
nt.assert_array_almost_equal(qlog[2].exp(), q1)
893+
nt.assert_array_almost_equal(qlog[3].exp(), q2)
894+
895+
q = UnitQuaternion() # identity
896+
qlog = q.log()
897+
nt.assert_array_almost_equal(qlog.vec, np.zeros(4))
898+
qq = qlog.exp()
899+
self.assertIsInstance(qq, UnitQuaternion)
900+
nt.assert_array_almost_equal(qq.vec, np.r_[1, 0, 0, 0])
901+
833902
def test_concat(self):
834903
u = Quaternion()
835904
uu = Quaternion([u, u, u, u])
@@ -1018,6 +1087,27 @@ def test_miscellany(self):
10181087
nt.assert_equal(q.inner(q), q.norm() ** 2)
10191088
nt.assert_equal(q.inner(u), np.dot(q.vec, u.vec))
10201089

1090+
# def test_mean(self):
1091+
# rpy = np.ones((100, 1)) @ np.c_[0.1, 0.2, 0.3]
1092+
# q = UnitQuaternion.RPY(rpy)
1093+
# self.assertEqual(len(q), 100)
1094+
# m = q.mean()
1095+
# self.assertIsInstance(m, UnitQuaternion)
1096+
# nt.assert_array_almost_equal(m.vec, q[0].vec)
1097+
1098+
# # range of angles, mean should be the middle one, index=25
1099+
# q = UnitQuaternion.Rz(np.linspace(start=0.3, stop=0.7, num=51))
1100+
# m = q.mean()
1101+
# self.assertIsInstance(m, UnitQuaternion)
1102+
# nt.assert_array_almost_equal(m.vec, q[25].vec)
1103+
1104+
# # now add noise
1105+
# rng = np.random.default_rng(0) # reproducible random numbers
1106+
# rpy += rng.normal(scale=0.1, size=(100, 3))
1107+
# q = UnitQuaternion.RPY(rpy)
1108+
# m = q.mean()
1109+
# nt.assert_array_almost_equal(m.vec, q.RPY(0.1, 0.2, 0.3).vec)
1110+
10211111

10221112
# ---------------------------------------------------------------------------------------#
10231113
if __name__ == "__main__":

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.