From 338ba200f564278105cd79deac3e4966092dad23 Mon Sep 17 00:00:00 2001 From: Peter Corke Date: Sun, 19 Jan 2025 22:23:11 +1000 Subject: [PATCH] Fix for issue #144. The `dim` option of `getunit()` was poorly defined and inconsistently used. Added a `vector` option to enforce a vector/scalar return, and `dim` is only used for checking. Fixed tests and all calls to `getunit()` where this case occurred. Also fixed some minor sphinx doco formatting problems. --- spatialmath/base/argcheck.py | 50 ++++++++++++++++++++++---------- spatialmath/base/transforms2d.py | 2 +- spatialmath/base/transforms3d.py | 35 +++++++++++++--------- spatialmath/base/vectors.py | 18 ++++++++---- spatialmath/quaternion.py | 2 +- tests/base/test_argcheck.py | 42 +++++++++++++++++++++++++-- 6 files changed, 109 insertions(+), 40 deletions(-) diff --git a/spatialmath/base/argcheck.py b/spatialmath/base/argcheck.py index 40f94336..38b5eb1a 100644 --- a/spatialmath/base/argcheck.py +++ b/spatialmath/base/argcheck.py @@ -522,7 +522,9 @@ def isvector(v: Any, dim: Optional[int] = None) -> bool: return False -def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]: +def getunit( + v: ArrayLike, unit: str = "rad", dim: Optional[int] = None, vector: bool = True +) -> Union[float, NDArray]: """ Convert values according to angular units @@ -530,8 +532,10 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]: :type v: array_like(m) :param unit: the angular unit, "rad" or "deg" :type unit: str - :param dim: expected dimension of input, defaults to None + :param dim: expected dimension of input, defaults to don't check (None) :type dim: int, optional + :param vector: return a scalar as a 1d vector, defaults to True + :type vector: bool, optional :return: the converted value in radians :rtype: ndarray(m) or float :raises ValueError: argument is not a valid angular unit @@ -543,30 +547,44 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]: >>> from spatialmath.base import getunit >>> import numpy as np >>> getunit(1.5, 'rad') - >>> getunit(1.5, 'rad', dim=0) - >>> # getunit([1.5], 'rad', dim=0) --> ValueError >>> getunit(90, 'deg') + >>> getunit(90, 'deg', vector=False) # force a scalar output + >>> getunit(1.5, 'rad', dim=0) # check argument is scalar + >>> getunit(1.5, 'rad', dim=3) # check argument is a 3-vector + >>> getunit([1.5], 'rad', dim=1) # check argument is a 1-vector + >>> getunit([1.5], 'rad', dim=3) # check argument is a 3-vector >>> getunit([90, 180], 'deg') - >>> getunit(np.r_[0.5, 1], 'rad') >>> getunit(np.r_[90, 180], 'deg') - >>> getunit(np.r_[90, 180], 'deg', dim=2) - >>> # getunit([90, 180], 'deg', dim=3) --> ValueError + >>> getunit(np.r_[90, 180], 'deg', dim=2) # check argument is a 2-vector + >>> getunit([90, 180], 'deg', dim=3) # check argument is a 3-vector :note: - the input value is processed by :func:`getvector` and the argument ``dim`` can - be used to check that ``v`` is the desired length. - - the output is always an ndarray except if the input is a scalar and ``dim=0``. + be used to check that ``v`` is the desired length. Note that 0 means a scalar, + whereas 1 means a 1-element array. + - the output is always an ndarray except if the input is a scalar and ``vector=False``. :seealso: :func:`getvector` """ - if not isinstance(v, Iterable) and dim == 0: - # scalar in, scalar out - if unit == "rad": - return v - elif unit == "deg": - return np.deg2rad(v) + if not isinstance(v, Iterable): + # scalar input + if dim is not None and dim != 0: + raise ValueError("for dim==0 input must be a scalar") + if vector: + # scalar in, vector out + if unit == "deg": + v = np.deg2rad(v) + elif unit != "rad": + raise ValueError("invalid angular units") + return np.array([v]) else: - raise ValueError("invalid angular units") + # scalar in, scalar out + if unit == "rad": + return v + elif unit == "deg": + return np.deg2rad(v) + else: + raise ValueError("invalid angular units") else: # scalar or iterable in, ndarray out diff --git a/spatialmath/base/transforms2d.py b/spatialmath/base/transforms2d.py index 682ea0ca..25265fff 100644 --- a/spatialmath/base/transforms2d.py +++ b/spatialmath/base/transforms2d.py @@ -63,7 +63,7 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array: >>> rot2(0.3) >>> rot2(45, 'deg') """ - theta = smb.getunit(theta, unit, dim=0) + theta = smb.getunit(theta, unit, vector=False) ct = smb.sym.cos(theta) st = smb.sym.sin(theta) # fmt: off diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index 3617f965..350e1d14 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -79,7 +79,7 @@ def rotx(theta: float, unit: str = "rad") -> SO3Array: :SymPy: supported """ - theta = getunit(theta, unit, dim=0) + theta = getunit(theta, unit, vector=False) ct = sym.cos(theta) st = sym.sin(theta) # fmt: off @@ -118,7 +118,7 @@ def roty(theta: float, unit: str = "rad") -> SO3Array: :SymPy: supported """ - theta = getunit(theta, unit, dim=0) + theta = getunit(theta, unit, vector=False) ct = sym.cos(theta) st = sym.sin(theta) # fmt: off @@ -152,7 +152,7 @@ def rotz(theta: float, unit: str = "rad") -> SO3Array: :seealso: :func:`~trotz` :SymPy: supported """ - theta = getunit(theta, unit, dim=0) + theta = getunit(theta, unit, vector=False) ct = sym.cos(theta) st = sym.sin(theta) # fmt: off @@ -2709,7 +2709,7 @@ def tr2adjoint(T): :Reference: - Robotics, Vision & Control for Python, Section 3, P. Corke, Springer 2023. - - `Lie groups for 2D and 3D Transformations _ + - `Lie groups for 2D and 3D Transformations `_ :SymPy: supported """ @@ -3002,29 +3002,36 @@ def trplot( - ``width`` of line - ``length`` of line - ``style`` which is one of: + - ``'arrow'`` [default], draw line with arrow head in ``color`` - ``'line'``, draw line with no arrow head in ``color`` - ``'rgb'``, frame axes are lines with no arrow head and red for X, green - for Y, blue for Z; no origin dot + for Y, blue for Z; no origin dot - ``'rviz'``, frame axes are thick lines with no arrow head and red for X, - green for Y, blue for Z; no origin dot + green for Y, blue for Z; no origin dot + - coordinate axis labels depend on: + - ``axislabel`` if True [default] label the axis, default labels are X, Y, Z - ``labels`` 3-list of alternative axis labels - ``textcolor`` which defaults to ``color`` - ``axissubscript`` if True [default] add the frame label ``frame`` as a subscript - for each axis label + for each axis label + - coordinate frame label depends on: + - `frame` the label placed inside {} near the origin of the frame + - a dot at the origin + - ``originsize`` size of the dot, if zero no dot - ``origincolor`` color of the dot, defaults to ``color`` Examples:: - trplot(T, frame='A') - trplot(T, frame='A', color='green') - trplot(T1, 'labels', 'UVW'); + trplot(T, frame='A') + trplot(T, frame='A', color='green') + trplot(T1, 'labels', 'UVW'); .. plot:: @@ -3383,12 +3390,12 @@ def tranimate(T: Union[SO3Array, SE3Array], **kwargs) -> str: :param **kwargs: arguments passed to ``trplot`` - ``tranimate(T)`` where ``T`` is an SO(3) or SE(3) matrix, animates a 3D - coordinate frame moving from the world frame to the frame ``T`` in - ``nsteps``. + coordinate frame moving from the world frame to the frame ``T`` in + ``nsteps``. - ``tranimate(I)`` where ``I`` is an iterable or generator, animates a 3D - coordinate frame representing the pose of each element in the sequence of - SO(3) or SE(3) matrices. + coordinate frame representing the pose of each element in the sequence of + SO(3) or SE(3) matrices. Examples: diff --git a/spatialmath/base/vectors.py b/spatialmath/base/vectors.py index f29740a3..bf95283f 100644 --- a/spatialmath/base/vectors.py +++ b/spatialmath/base/vectors.py @@ -530,6 +530,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]: :param theta: input angle :type theta: scalar or ndarray :return: angle wrapped into range :math:`[0, \pi)` + :rtype: scalar or ndarray This is used to fold angles of colatitude. If zero is the angle of the north pole, colatitude increases to :math:`\pi` at the south pole then @@ -537,7 +538,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]: :seealso: :func:`wrap_mpi2_pi2` :func:`wrap_0_2pi` :func:`wrap_mpi_pi` :func:`angle_wrap` """ - theta = np.abs(theta) + theta = np.abs(getvector(theta)) n = theta / np.pi if isinstance(n, np.ndarray): n = n.astype(int) @@ -546,7 +547,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]: y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, (n + 1) * np.pi - theta) if isinstance(y, np.ndarray) and y.size == 1: - return float(y) + return float(y[0]) else: return y @@ -558,6 +559,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]: :param theta: input angle :type theta: scalar or ndarray :return: angle wrapped into range :math:`[-\pi/2, \pi/2]` + :rtype: scalar or ndarray This is used to fold angles of latitude. @@ -573,7 +575,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]: y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, n * np.pi - theta) if isinstance(y, np.ndarray) and len(y) == 1: - return float(y) + return float(y[0]) else: return y @@ -585,13 +587,14 @@ def wrap_0_2pi(theta: ArrayLike) -> Union[float, NDArray]: :param theta: input angle :type theta: scalar or ndarray :return: angle wrapped into range :math:`[0, 2\pi)` + :rtype: scalar or ndarray :seealso: :func:`wrap_mpi_pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap` """ theta = getvector(theta) y = theta - 2.0 * math.pi * np.floor(theta / 2.0 / np.pi) if isinstance(y, np.ndarray) and len(y) == 1: - return float(y) + return float(y[0]) else: return y @@ -603,13 +606,14 @@ def wrap_mpi_pi(theta: ArrayLike) -> Union[float, NDArray]: :param theta: input angle :type theta: scalar or ndarray :return: angle wrapped into range :math:`[-\pi, \pi)` + :rtype: scalar or ndarray :seealso: :func:`wrap_0_2pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap` """ theta = getvector(theta) y = np.mod(theta + math.pi, 2 * math.pi) - np.pi if isinstance(y, np.ndarray) and len(y) == 1: - return float(y) + return float(y[0]) else: return y @@ -643,6 +647,7 @@ def angdiff(a, b=None): - ``angdiff(a, b)`` is the difference ``a - b`` wrapped to the range :math:`[-\pi, \pi)`. This is the operator :math:`a \circleddash b` used in the RVC book + - If ``a`` and ``b`` are both scalars, the result is scalar - If ``a`` is array_like, the result is a NumPy array ``a[i]-b`` - If ``a`` is array_like, the result is a NumPy array ``a-b[i]`` @@ -651,6 +656,7 @@ def angdiff(a, b=None): - ``angdiff(a)`` is the angle or vector of angles ``a`` wrapped to the range :math:`[-\pi, \pi)`. + - If ``a`` is a scalar, the result is scalar - If ``a`` is array_like, the result is a NumPy array @@ -671,7 +677,7 @@ def angdiff(a, b=None): y = np.mod(a + math.pi, 2 * math.pi) - math.pi if isinstance(y, np.ndarray) and len(y) == 1: - return float(y) + return float(y[0]) else: return y diff --git a/spatialmath/quaternion.py b/spatialmath/quaternion.py index 51561036..e77d12c8 100644 --- a/spatialmath/quaternion.py +++ b/spatialmath/quaternion.py @@ -1411,7 +1411,7 @@ def AngVec( :seealso: :meth:`UnitQuaternion.angvec` :meth:`UnitQuaternion.exp` :func:`~spatialmath.base.transforms3d.angvec2r` """ v = smb.getvector(v, 3) - theta = smb.getunit(theta, unit, dim=0) + theta = smb.getunit(theta, unit, vector=False) return cls( s=math.cos(theta / 2), v=math.sin(theta / 2) * v, norm=False, check=False ) diff --git a/tests/base/test_argcheck.py b/tests/base/test_argcheck.py index 685393b5..39c943d1 100755 --- a/tests/base/test_argcheck.py +++ b/tests/base/test_argcheck.py @@ -122,11 +122,49 @@ def test_verifymatrix(self): verifymatrix(a, (3, 4)) def test_unit(self): - self.assertIsInstance(getunit(1), np.ndarray) + # scalar -> vector + self.assertEqual(getunit(1), np.array([1])) + self.assertEqual(getunit(1, dim=0), np.array([1])) + with self.assertRaises(ValueError): + self.assertEqual(getunit(1, dim=1), np.array([1])) + + self.assertEqual(getunit(1, unit="deg"), np.array([1 * math.pi / 180.0])) + self.assertEqual(getunit(1, dim=0, unit="deg"), np.array([1 * math.pi / 180.0])) + with self.assertRaises(ValueError): + self.assertEqual( + getunit(1, dim=1, unit="deg"), np.array([1 * math.pi / 180.0]) + ) + + # scalar -> scalar + self.assertEqual(getunit(1, vector=False), 1) + self.assertEqual(getunit(1, dim=0, vector=False), 1) + with self.assertRaises(ValueError): + self.assertEqual(getunit(1, dim=1, vector=False), 1) + + self.assertIsInstance(getunit(1.0, vector=False), float) + self.assertIsInstance(getunit(1, vector=False), int) + + self.assertEqual(getunit(1, vector=False, unit="deg"), 1 * math.pi / 180.0) + self.assertEqual( + getunit(1, dim=0, vector=False, unit="deg"), 1 * math.pi / 180.0 + ) + with self.assertRaises(ValueError): + self.assertEqual( + getunit(1, dim=1, vector=False, unit="deg"), 1 * math.pi / 180.0 + ) + + self.assertIsInstance(getunit(1.0, vector=False, unit="deg"), float) + self.assertIsInstance(getunit(1, vector=False, unit="deg"), float) + + # vector -> vector + self.assertEqual(getunit([1]), np.array([1])) + self.assertEqual(getunit([1], dim=1), np.array([1])) + with self.assertRaises(ValueError): + getunit([1], dim=0) + self.assertIsInstance(getunit([1, 2]), np.ndarray) self.assertIsInstance(getunit((1, 2)), np.ndarray) self.assertIsInstance(getunit(np.r_[1, 2]), np.ndarray) - self.assertIsInstance(getunit(1.0, dim=0), float) nt.assert_equal(getunit(5, "rad"), 5) nt.assert_equal(getunit(5, "deg"), 5 * math.pi / 180.0)