Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 39d6127

Browse filesBrowse files
authored
Fix for issue #144. (#155)
1 parent 659ba24 commit 39d6127
Copy full SHA for 39d6127

File tree

Expand file treeCollapse file tree

6 files changed

+109
-40
lines changed
Filter options
Expand file treeCollapse file tree

6 files changed

+109
-40
lines changed

‎spatialmath/base/argcheck.py

Copy file name to clipboardExpand all lines: spatialmath/base/argcheck.py
+34-16Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -522,16 +522,20 @@ def isvector(v: Any, dim: Optional[int] = None) -> bool:
522522
return False
523523

524524

525-
def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
525+
def getunit(
526+
v: ArrayLike, unit: str = "rad", dim: Optional[int] = None, vector: bool = True
527+
) -> Union[float, NDArray]:
526528
"""
527529
Convert values according to angular units
528530
529531
:param v: the value in radians or degrees
530532
:type v: array_like(m)
531533
:param unit: the angular unit, "rad" or "deg"
532534
:type unit: str
533-
:param dim: expected dimension of input, defaults to None
535+
:param dim: expected dimension of input, defaults to don't check (None)
534536
:type dim: int, optional
537+
:param vector: return a scalar as a 1d vector, defaults to True
538+
:type vector: bool, optional
535539
:return: the converted value in radians
536540
:rtype: ndarray(m) or float
537541
:raises ValueError: argument is not a valid angular unit
@@ -543,30 +547,44 @@ def getunit(v: ArrayLike, unit: str = "rad", dim=None) -> Union[float, NDArray]:
543547
>>> from spatialmath.base import getunit
544548
>>> import numpy as np
545549
>>> getunit(1.5, 'rad')
546-
>>> getunit(1.5, 'rad', dim=0)
547-
>>> # getunit([1.5], 'rad', dim=0) --> ValueError
548550
>>> getunit(90, 'deg')
551+
>>> getunit(90, 'deg', vector=False) # force a scalar output
552+
>>> getunit(1.5, 'rad', dim=0) # check argument is scalar
553+
>>> getunit(1.5, 'rad', dim=3) # check argument is a 3-vector
554+
>>> getunit([1.5], 'rad', dim=1) # check argument is a 1-vector
555+
>>> getunit([1.5], 'rad', dim=3) # check argument is a 3-vector
549556
>>> getunit([90, 180], 'deg')
550-
>>> getunit(np.r_[0.5, 1], 'rad')
551557
>>> getunit(np.r_[90, 180], 'deg')
552-
>>> getunit(np.r_[90, 180], 'deg', dim=2)
553-
>>> # getunit([90, 180], 'deg', dim=3) --> ValueError
558+
>>> getunit(np.r_[90, 180], 'deg', dim=2) # check argument is a 2-vector
559+
>>> getunit([90, 180], 'deg', dim=3) # check argument is a 3-vector
554560
555561
:note:
556562
- the input value is processed by :func:`getvector` and the argument ``dim`` can
557-
be used to check that ``v`` is the desired length.
558-
- the output is always an ndarray except if the input is a scalar and ``dim=0``.
563+
be used to check that ``v`` is the desired length. Note that 0 means a scalar,
564+
whereas 1 means a 1-element array.
565+
- the output is always an ndarray except if the input is a scalar and ``vector=False``.
559566
560567
:seealso: :func:`getvector`
561568
"""
562-
if not isinstance(v, Iterable) and dim == 0:
563-
# scalar in, scalar out
564-
if unit == "rad":
565-
return v
566-
elif unit == "deg":
567-
return np.deg2rad(v)
569+
if not isinstance(v, Iterable):
570+
# scalar input
571+
if dim is not None and dim != 0:
572+
raise ValueError("for dim==0 input must be a scalar")
573+
if vector:
574+
# scalar in, vector out
575+
if unit == "deg":
576+
v = np.deg2rad(v)
577+
elif unit != "rad":
578+
raise ValueError("invalid angular units")
579+
return np.array([v])
568580
else:
569-
raise ValueError("invalid angular units")
581+
# scalar in, scalar out
582+
if unit == "rad":
583+
return v
584+
elif unit == "deg":
585+
return np.deg2rad(v)
586+
else:
587+
raise ValueError("invalid angular units")
570588

571589
else:
572590
# scalar or iterable in, ndarray out

‎spatialmath/base/transforms2d.py

Copy file name to clipboardExpand all lines: spatialmath/base/transforms2d.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def rot2(theta: float, unit: str = "rad") -> SO2Array:
6363
>>> rot2(0.3)
6464
>>> rot2(45, 'deg')
6565
"""
66-
theta = smb.getunit(theta, unit, dim=0)
66+
theta = smb.getunit(theta, unit, vector=False)
6767
ct = smb.sym.cos(theta)
6868
st = smb.sym.sin(theta)
6969
# fmt: off

‎spatialmath/base/transforms3d.py

Copy file name to clipboardExpand all lines: spatialmath/base/transforms3d.py
+21-14Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def rotx(theta: float, unit: str = "rad") -> SO3Array:
7979
:SymPy: supported
8080
"""
8181

82-
theta = getunit(theta, unit, dim=0)
82+
theta = getunit(theta, unit, vector=False)
8383
ct = sym.cos(theta)
8484
st = sym.sin(theta)
8585
# fmt: off
@@ -118,7 +118,7 @@ def roty(theta: float, unit: str = "rad") -> SO3Array:
118118
:SymPy: supported
119119
"""
120120

121-
theta = getunit(theta, unit, dim=0)
121+
theta = getunit(theta, unit, vector=False)
122122
ct = sym.cos(theta)
123123
st = sym.sin(theta)
124124
# fmt: off
@@ -152,7 +152,7 @@ def rotz(theta: float, unit: str = "rad") -> SO3Array:
152152
:seealso: :func:`~trotz`
153153
:SymPy: supported
154154
"""
155-
theta = getunit(theta, unit, dim=0)
155+
theta = getunit(theta, unit, vector=False)
156156
ct = sym.cos(theta)
157157
st = sym.sin(theta)
158158
# fmt: off
@@ -2709,7 +2709,7 @@ def tr2adjoint(T):
27092709
27102710
:Reference:
27112711
- Robotics, Vision & Control for Python, Section 3, P. Corke, Springer 2023.
2712-
- `Lie groups for 2D and 3D Transformations <http://ethaneade.com/lie.pdf>_
2712+
- `Lie groups for 2D and 3D Transformations <http://ethaneade.com/lie.pdf>`_
27132713
27142714
:SymPy: supported
27152715
"""
@@ -3002,29 +3002,36 @@ def trplot(
30023002
- ``width`` of line
30033003
- ``length`` of line
30043004
- ``style`` which is one of:
3005+
30053006
- ``'arrow'`` [default], draw line with arrow head in ``color``
30063007
- ``'line'``, draw line with no arrow head in ``color``
30073008
- ``'rgb'``, frame axes are lines with no arrow head and red for X, green
3008-
for Y, blue for Z; no origin dot
3009+
for Y, blue for Z; no origin dot
30093010
- ``'rviz'``, frame axes are thick lines with no arrow head and red for X,
3010-
green for Y, blue for Z; no origin dot
3011+
green for Y, blue for Z; no origin dot
3012+
30113013
- coordinate axis labels depend on:
3014+
30123015
- ``axislabel`` if True [default] label the axis, default labels are X, Y, Z
30133016
- ``labels`` 3-list of alternative axis labels
30143017
- ``textcolor`` which defaults to ``color``
30153018
- ``axissubscript`` if True [default] add the frame label ``frame`` as a subscript
3016-
for each axis label
3019+
for each axis label
3020+
30173021
- coordinate frame label depends on:
3022+
30183023
- `frame` the label placed inside {} near the origin of the frame
3024+
30193025
- a dot at the origin
3026+
30203027
- ``originsize`` size of the dot, if zero no dot
30213028
- ``origincolor`` color of the dot, defaults to ``color``
30223029
30233030
Examples::
30243031
3025-
trplot(T, frame='A')
3026-
trplot(T, frame='A', color='green')
3027-
trplot(T1, 'labels', 'UVW');
3032+
trplot(T, frame='A')
3033+
trplot(T, frame='A', color='green')
3034+
trplot(T1, 'labels', 'UVW');
30283035
30293036
.. plot::
30303037
@@ -3383,12 +3390,12 @@ def tranimate(T: Union[SO3Array, SE3Array], **kwargs) -> str:
33833390
:param **kwargs: arguments passed to ``trplot``
33843391
33853392
- ``tranimate(T)`` where ``T`` is an SO(3) or SE(3) matrix, animates a 3D
3386-
coordinate frame moving from the world frame to the frame ``T`` in
3387-
``nsteps``.
3393+
coordinate frame moving from the world frame to the frame ``T`` in
3394+
``nsteps``.
33883395
33893396
- ``tranimate(I)`` where ``I`` is an iterable or generator, animates a 3D
3390-
coordinate frame representing the pose of each element in the sequence of
3391-
SO(3) or SE(3) matrices.
3397+
coordinate frame representing the pose of each element in the sequence of
3398+
SO(3) or SE(3) matrices.
33923399
33933400
Examples:
33943401

‎spatialmath/base/vectors.py

Copy file name to clipboardExpand all lines: spatialmath/base/vectors.py
+12-6Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -530,14 +530,15 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]:
530530
:param theta: input angle
531531
:type theta: scalar or ndarray
532532
:return: angle wrapped into range :math:`[0, \pi)`
533+
:rtype: scalar or ndarray
533534
534535
This is used to fold angles of colatitude. If zero is the angle of the
535536
north pole, colatitude increases to :math:`\pi` at the south pole then
536537
decreases to :math:`0` as we head back to the north pole.
537538
538539
:seealso: :func:`wrap_mpi2_pi2` :func:`wrap_0_2pi` :func:`wrap_mpi_pi` :func:`angle_wrap`
539540
"""
540-
theta = np.abs(theta)
541+
theta = np.abs(getvector(theta))
541542
n = theta / np.pi
542543
if isinstance(n, np.ndarray):
543544
n = n.astype(int)
@@ -546,7 +547,7 @@ def wrap_0_pi(theta: ArrayLike) -> Union[float, NDArray]:
546547

547548
y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, (n + 1) * np.pi - theta)
548549
if isinstance(y, np.ndarray) and y.size == 1:
549-
return float(y)
550+
return float(y[0])
550551
else:
551552
return y
552553

@@ -558,6 +559,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]:
558559
:param theta: input angle
559560
:type theta: scalar or ndarray
560561
:return: angle wrapped into range :math:`[-\pi/2, \pi/2]`
562+
:rtype: scalar or ndarray
561563
562564
This is used to fold angles of latitude.
563565
@@ -573,7 +575,7 @@ def wrap_mpi2_pi2(theta: ArrayLike) -> Union[float, NDArray]:
573575

574576
y = np.where(np.bitwise_and(n, 1) == 0, theta - n * np.pi, n * np.pi - theta)
575577
if isinstance(y, np.ndarray) and len(y) == 1:
576-
return float(y)
578+
return float(y[0])
577579
else:
578580
return y
579581

@@ -585,13 +587,14 @@ def wrap_0_2pi(theta: ArrayLike) -> Union[float, NDArray]:
585587
:param theta: input angle
586588
:type theta: scalar or ndarray
587589
:return: angle wrapped into range :math:`[0, 2\pi)`
590+
:rtype: scalar or ndarray
588591
589592
:seealso: :func:`wrap_mpi_pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap`
590593
"""
591594
theta = getvector(theta)
592595
y = theta - 2.0 * math.pi * np.floor(theta / 2.0 / np.pi)
593596
if isinstance(y, np.ndarray) and len(y) == 1:
594-
return float(y)
597+
return float(y[0])
595598
else:
596599
return y
597600

@@ -603,13 +606,14 @@ def wrap_mpi_pi(theta: ArrayLike) -> Union[float, NDArray]:
603606
:param theta: input angle
604607
:type theta: scalar or ndarray
605608
:return: angle wrapped into range :math:`[-\pi, \pi)`
609+
:rtype: scalar or ndarray
606610
607611
:seealso: :func:`wrap_0_2pi` :func:`wrap_0_pi` :func:`wrap_mpi2_pi2` :func:`angle_wrap`
608612
"""
609613
theta = getvector(theta)
610614
y = np.mod(theta + math.pi, 2 * math.pi) - np.pi
611615
if isinstance(y, np.ndarray) and len(y) == 1:
612-
return float(y)
616+
return float(y[0])
613617
else:
614618
return y
615619

@@ -643,6 +647,7 @@ def angdiff(a, b=None):
643647
- ``angdiff(a, b)`` is the difference ``a - b`` wrapped to the range
644648
:math:`[-\pi, \pi)`. This is the operator :math:`a \circleddash b` used
645649
in the RVC book
650+
646651
- If ``a`` and ``b`` are both scalars, the result is scalar
647652
- If ``a`` is array_like, the result is a NumPy array ``a[i]-b``
648653
- If ``a`` is array_like, the result is a NumPy array ``a-b[i]``
@@ -651,6 +656,7 @@ def angdiff(a, b=None):
651656
652657
- ``angdiff(a)`` is the angle or vector of angles ``a`` wrapped to the range
653658
:math:`[-\pi, \pi)`.
659+
654660
- If ``a`` is a scalar, the result is scalar
655661
- If ``a`` is array_like, the result is a NumPy array
656662
@@ -671,7 +677,7 @@ def angdiff(a, b=None):
671677

672678
y = np.mod(a + math.pi, 2 * math.pi) - math.pi
673679
if isinstance(y, np.ndarray) and len(y) == 1:
674-
return float(y)
680+
return float(y[0])
675681
else:
676682
return y
677683

‎spatialmath/quaternion.py

Copy file name to clipboardExpand all lines: spatialmath/quaternion.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,7 @@ def AngVec(
14431443
:seealso: :meth:`UnitQuaternion.angvec` :meth:`UnitQuaternion.exp` :func:`~spatialmath.base.transforms3d.angvec2r`
14441444
"""
14451445
v = smb.getvector(v, 3)
1446-
theta = smb.getunit(theta, unit, dim=0)
1446+
theta = smb.getunit(theta, unit, vector=False)
14471447
return cls(
14481448
s=math.cos(theta / 2), v=math.sin(theta / 2) * v, norm=False, check=False
14491449
)

‎tests/base/test_argcheck.py

Copy file name to clipboardExpand all lines: tests/base/test_argcheck.py
+40-2Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,49 @@ def test_verifymatrix(self):
122122
verifymatrix(a, (3, 4))
123123

124124
def test_unit(self):
125-
self.assertIsInstance(getunit(1), np.ndarray)
125+
# scalar -> vector
126+
self.assertEqual(getunit(1), np.array([1]))
127+
self.assertEqual(getunit(1, dim=0), np.array([1]))
128+
with self.assertRaises(ValueError):
129+
self.assertEqual(getunit(1, dim=1), np.array([1]))
130+
131+
self.assertEqual(getunit(1, unit="deg"), np.array([1 * math.pi / 180.0]))
132+
self.assertEqual(getunit(1, dim=0, unit="deg"), np.array([1 * math.pi / 180.0]))
133+
with self.assertRaises(ValueError):
134+
self.assertEqual(
135+
getunit(1, dim=1, unit="deg"), np.array([1 * math.pi / 180.0])
136+
)
137+
138+
# scalar -> scalar
139+
self.assertEqual(getunit(1, vector=False), 1)
140+
self.assertEqual(getunit(1, dim=0, vector=False), 1)
141+
with self.assertRaises(ValueError):
142+
self.assertEqual(getunit(1, dim=1, vector=False), 1)
143+
144+
self.assertIsInstance(getunit(1.0, vector=False), float)
145+
self.assertIsInstance(getunit(1, vector=False), int)
146+
147+
self.assertEqual(getunit(1, vector=False, unit="deg"), 1 * math.pi / 180.0)
148+
self.assertEqual(
149+
getunit(1, dim=0, vector=False, unit="deg"), 1 * math.pi / 180.0
150+
)
151+
with self.assertRaises(ValueError):
152+
self.assertEqual(
153+
getunit(1, dim=1, vector=False, unit="deg"), 1 * math.pi / 180.0
154+
)
155+
156+
self.assertIsInstance(getunit(1.0, vector=False, unit="deg"), float)
157+
self.assertIsInstance(getunit(1, vector=False, unit="deg"), float)
158+
159+
# vector -> vector
160+
self.assertEqual(getunit([1]), np.array([1]))
161+
self.assertEqual(getunit([1], dim=1), np.array([1]))
162+
with self.assertRaises(ValueError):
163+
getunit([1], dim=0)
164+
126165
self.assertIsInstance(getunit([1, 2]), np.ndarray)
127166
self.assertIsInstance(getunit((1, 2)), np.ndarray)
128167
self.assertIsInstance(getunit(np.r_[1, 2]), np.ndarray)
129-
self.assertIsInstance(getunit(1.0, dim=0), float)
130168

131169
nt.assert_equal(getunit(5, "rad"), 5)
132170
nt.assert_equal(getunit(5, "deg"), 5 * math.pi / 180.0)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.