15
15
import numpy as np
16
16
from spatialmath .base import symbolic as sym
17
17
18
+ from spatialmath .base .sm_numba import numba_njit , numba_overload
19
+
18
20
# valid scalar types
19
21
_scalartypes = (int , np .integer , float , np .floating ) + sym .symtype
20
22
21
-
22
- # def getunit(v, unit='rad'):
23
- # """
24
- # Convert value according to angular units
25
-
26
- # :param v: the value in radians or degrees
27
- # :type v: array_like(m) or ndarray(m)
28
- # :param unit: the angular unit, "rad" or "deg"
29
- # :type unit: str
30
- # :return: the converted value in radians
31
- # :rtype: list(m) or ndarray(m)
32
- # :raises ValueError: argument is not a valid angular unit
33
-
34
- # .. runblock:: pycon
35
-
36
- # >>> from spatialmath.base import getunit
37
- # >>> import numpy as np
38
- # >>> getunit(1.5, 'rad')
39
- # >>> getunit(90, 'deg')
40
- # >>> getunit([90, 180], 'deg')
41
- # >>> getunit(np.r_[0.5, 1], 'rad')
42
- # >>> getunit(np.r_[90, 180], 'deg')
43
- # """
44
- # if unit == "rad":
45
- # return v
46
- # elif unit == "deg":
47
- # if isinstance(v, np.ndarray) or np.isscalar(v):
48
- # return v * math.pi / 180
49
- # else:
50
- # return [x * math.pi / 180 for x in v]
51
- # else:
52
- # raise ValueError("invalid angular units")
53
-
54
- # @numba.generated_jit(nopython=True)
55
- # def getunit(v, unit='rad'):
56
- # """
57
- # Convert value according to angular units
58
-
59
- # :param v: the value in radians or degrees
60
- # :type v: array_like(m) or ndarray(m)
61
- # :param unit: the angular unit, "rad" or "deg"
62
- # :type unit: str
63
- # :return: the converted value in radians
64
- # :rtype: list(m) or ndarray(m)
65
- # :raises ValueError: argument is not a valid angular unit
66
-
67
- # .. runblock:: pycon
68
-
69
- # >>> from spatialmath.base import getunit
70
- # >>> import numpy as np
71
- # >>> getunit(1.5, 'rad')
72
- # >>> getunit(90, 'deg')
73
- # >>> getunit([90, 180], 'deg')
74
- # >>> getunit(np.r_[0.5, 1], 'rad')
75
- # >>> getunit(np.r_[90, 180], 'deg')
76
- # """
77
-
78
- # @numba.njit
79
- # def conv(v, unit):
80
- # if unit == "rad":
81
- # return v
82
- # elif unit == "deg":
83
- # return v * math.pi / 180
84
-
85
- # @numba.njit
86
- # def conv_l(v, unit):
87
- # if unit == "rad":
88
- # return v
89
- # elif unit == "deg":
90
- # return [y * np.pi / 180 for y in x]
91
-
92
- # if isinstance(v, np.ndarray) or isinstance(v, numba.types.Float):
93
- # return lambda a, b: conv(a, b)
94
-
95
- # # elif isinstance(v, list):
96
- # # return lambda a, b: conv_l(a, b)
97
-
98
- # # if unit == "rad":
99
- # # # return v
100
- # # return lambda v: v
101
- # # elif unit == "deg":
102
- # # if isinstance(v, np.ndarray) or np.isscalar(v):
103
- # # # return v * math.pi / 180
104
- # # return lambda x: x * np.pi / 180
105
- # # else:
106
- # # # return [x * math.pi / 180 for x in v]
107
- # # return lambda x: [y * np.pi / 180 for y in x]
108
- # # else:
109
- # # raise ValueError("invalid angular units")
110
-
111
23
def isscalar (x ):
112
24
"""
113
25
Test if argument is a real scalar
@@ -537,6 +449,21 @@ def getunit(v, unit='rad'):
537
449
else :
538
450
raise ValueError ("invalid angular units" )
539
451
452
+ @numba_overload (getunit )
453
+ def getunit_numba (v , unit = 'rad' ):
454
+ def n (v , unit = 'rad' ):
455
+ if unit == "rad" :
456
+ return v
457
+ elif unit == "deg" :
458
+ return v * np .pi / 180.0
459
+ else :
460
+ return v
461
+ return n
462
+
463
+ @numba_njit
464
+ def getunit_fast (v , unit = 'rad' ):
465
+ return getunit (v , unit = 'rad' )
466
+
540
467
541
468
def isnumberlist (x ):
542
469
"""
0 commit comments