Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 72d97b3

Browse filesBrowse files
committed
speed
1 parent a29ad8a commit 72d97b3
Copy full SHA for 72d97b3

File tree

Expand file treeCollapse file tree

7 files changed

+185
-13
lines changed
Filter options
Expand file treeCollapse file tree

7 files changed

+185
-13
lines changed

‎spatialmath/base/__init__.py

Copy file name to clipboardExpand all lines: spatialmath/base/__init__.py
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212
from spatialmath.base.animate import * # lgtm [py/polluting-import]
1313
from spatialmath.base.graphics import * # lgtm [py/polluting-import]
1414

15+
# from spatialmath.base.sm_numba import use_numba, numba_njit, using_numba
16+
# from spatialmath.base.sm_numba import *
17+
1518
__all__ = [
19+
# 'use_numba',
20+
# 'numba_njit',
21+
# 'using_numba',
1622
# spatialmath.base.argcheck
1723
'assertmatrix',
1824
'ismatrix',

‎spatialmath/base/argcheck.py

Copy file name to clipboardExpand all lines: spatialmath/base/argcheck.py
+90Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,96 @@
1818
# valid scalar types
1919
_scalartypes = (int, np.integer, float, np.floating) + sym.symtype
2020

21+
22+
# def getunit(v, unit='rad'):
23+
# """
24+
# Convert value according to angular units
25+
26+
# :param v: the value in radians or degrees
27+
# :type v: array_like(m) or ndarray(m)
28+
# :param unit: the angular unit, "rad" or "deg"
29+
# :type unit: str
30+
# :return: the converted value in radians
31+
# :rtype: list(m) or ndarray(m)
32+
# :raises ValueError: argument is not a valid angular unit
33+
34+
# .. runblock:: pycon
35+
36+
# >>> from spatialmath.base import getunit
37+
# >>> import numpy as np
38+
# >>> getunit(1.5, 'rad')
39+
# >>> getunit(90, 'deg')
40+
# >>> getunit([90, 180], 'deg')
41+
# >>> getunit(np.r_[0.5, 1], 'rad')
42+
# >>> getunit(np.r_[90, 180], 'deg')
43+
# """
44+
# if unit == "rad":
45+
# return v
46+
# elif unit == "deg":
47+
# if isinstance(v, np.ndarray) or np.isscalar(v):
48+
# return v * math.pi / 180
49+
# else:
50+
# return [x * math.pi / 180 for x in v]
51+
# else:
52+
# raise ValueError("invalid angular units")
53+
54+
# @numba.generated_jit(nopython=True)
55+
# def getunit(v, unit='rad'):
56+
# """
57+
# Convert value according to angular units
58+
59+
# :param v: the value in radians or degrees
60+
# :type v: array_like(m) or ndarray(m)
61+
# :param unit: the angular unit, "rad" or "deg"
62+
# :type unit: str
63+
# :return: the converted value in radians
64+
# :rtype: list(m) or ndarray(m)
65+
# :raises ValueError: argument is not a valid angular unit
66+
67+
# .. runblock:: pycon
68+
69+
# >>> from spatialmath.base import getunit
70+
# >>> import numpy as np
71+
# >>> getunit(1.5, 'rad')
72+
# >>> getunit(90, 'deg')
73+
# >>> getunit([90, 180], 'deg')
74+
# >>> getunit(np.r_[0.5, 1], 'rad')
75+
# >>> getunit(np.r_[90, 180], 'deg')
76+
# """
77+
78+
# @numba.njit
79+
# def conv(v, unit):
80+
# if unit == "rad":
81+
# return v
82+
# elif unit == "deg":
83+
# return v * math.pi / 180
84+
85+
# @numba.njit
86+
# def conv_l(v, unit):
87+
# if unit == "rad":
88+
# return v
89+
# elif unit == "deg":
90+
# return [y * np.pi / 180 for y in x]
91+
92+
# if isinstance(v, np.ndarray) or isinstance(v, numba.types.Float):
93+
# return lambda a, b: conv(a, b)
94+
95+
# # elif isinstance(v, list):
96+
# # return lambda a, b: conv_l(a, b)
97+
98+
# # if unit == "rad":
99+
# # # return v
100+
# # return lambda v: v
101+
# # elif unit == "deg":
102+
# # if isinstance(v, np.ndarray) or np.isscalar(v):
103+
# # # return v * math.pi / 180
104+
# # return lambda x: x * np.pi / 180
105+
# # else:
106+
# # # return [x * math.pi / 180 for x in v]
107+
# # return lambda x: [y * np.pi / 180 for y in x]
108+
# # else:
109+
# # raise ValueError("invalid angular units")
110+
21111
def isscalar(x):
22112
"""
23113
Test if argument is a real scalar

‎spatialmath/base/quaternions.py

Copy file name to clipboardExpand all lines: spatialmath/base/quaternions.py
+17-9Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88
import math
99
import numpy as np
1010
from spatialmath import base
11+
# import spatialmath.base as base
12+
import numba
1113

12-
_eps = np.finfo(np.float64).eps
14+
from spatialmath.base.sm_numba import numba_njit
1315

16+
_eps = np.finfo(np.float64).eps
1417

18+
@numba.njit
1519
def eye():
1620
"""
1721
Create an identity quaternion
@@ -29,7 +33,7 @@ def eye():
2933
>>> qprint(q)
3034
3135
"""
32-
return np.r_[1, 0, 0, 0]
36+
return np.array([1.0, 0.0, 0.0, 0.0])
3337

3438

3539
def pure(v):
@@ -493,6 +497,9 @@ def q2r(q):
493497
[2 * (x * z - s * y), 2 * (y * z + s * x), 1 - 2 * (x ** 2 + y ** 2)]])
494498

495499

500+
# @numba.njit
501+
502+
@numba_njit
496503
def r2q(R, check=False, tol=100):
497504
"""
498505
Convert SO(3) rotation matrix to unit-quaternion
@@ -522,10 +529,11 @@ def r2q(R, check=False, tol=100):
522529
523530
:seealso: :func:`q2r`
524531
"""
525-
if not base.isrot(R, check=check, tol=tol):
526-
raise ValueError("Argument must be a valid SO(3) matrix")
532+
# if not base.isrot(R, check=check, tol=tol):
533+
# print('a')
534+
# raise ValueError("Argument must be a valid SO(3) matrix")
527535

528-
qs = math.sqrt(max(0, np.trace(R) + 1)) / 2.0 # scalar part
536+
qs = np.sqrt(max(0, np.trace(R) + 1)) / 2.0 # scalar part
529537
kx = R[2, 1] - R[1, 2] # Oz - Ay
530538
ky = R[0, 2] - R[2, 0] # Ax - Nz
531539
kz = R[1, 0] - R[0, 1] # Ny - Ox
@@ -555,13 +563,13 @@ def r2q(R, check=False, tol=100):
555563
ky = ky - ky1
556564
kz = kz - kz1
557565

558-
kv = np.r_[kx, ky, kz]
559-
nm = np.linalg.norm(kv)
566+
kv = np.array([qs, kx, ky, kz])
567+
nm = base.norm(kv[1:])
560568
if abs(nm) < tol * _eps:
561569
return eye()
562570
else:
563-
return np.r_[qs, (math.sqrt(1.0 - qs ** 2) / nm) * kv]
564-
571+
kv[1:] = (np.sqrt(1.0 - qs ** 2) / nm) * kv[1:]
572+
return kv
565573

566574
def slerp(q0, q1, s, shortest=False):
567575
"""

‎spatialmath/base/sm_numba.py

Copy file name to clipboard
+43Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
2+
# _numba = True
3+
from functools import wraps
4+
5+
try:
6+
global _numba
7+
import numba
8+
_numba = True
9+
except ImportError:
10+
# global _numba
11+
_numba = False
12+
13+
# Equivalent to @numba.njit
14+
def numba_njit(func):
15+
if _numba:
16+
junc = numba.njit(func)
17+
18+
@wraps(func)
19+
def wrapper(*args, **kwargs):
20+
if _numba:
21+
return junc(*args, **kwargs)
22+
else:
23+
return func(*args, **kwargs)
24+
25+
return wrapper
26+
27+
# Equivalent to @numba.extending.overload(method)
28+
def numba_overload(method):
29+
def decor(func):
30+
if _numba:
31+
return (numba.extending.overload(method))(func)
32+
else:
33+
return func
34+
35+
return decor
36+
37+
def use_numba(use_numba):
38+
global _numba
39+
_numba = use_numba
40+
41+
def using_numba():
42+
global _numba
43+
return _numba

‎spatialmath/base/transforms3d.py

Copy file name to clipboardExpand all lines: spatialmath/base/transforms3d.py
+14-1Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from math import sin, cos
1919
import numpy as np
2020
from spatialmath import base
21+
import numba
2122

2223
_eps = np.finfo(np.float64).eps
2324

@@ -58,6 +59,19 @@ def rotx(theta, unit="rad"):
5859
[0, st, ct]])
5960
return R
6061

62+
@numba.extending.overload(rotx)
63+
def fast_rotx(theta, unit="rad"):
64+
def n(v):
65+
# theta = base.getunit(theta, unit)
66+
ct = np.cos(theta)
67+
st = np.sin(theta)
68+
R = np.array([
69+
[1, 0, 0],
70+
[0, ct, -st],
71+
[0, st, ct]])
72+
return R
73+
return n
74+
6175

6276
# ---------------------------------------------------------------------------------------#
6377
def roty(theta, unit="rad"):
@@ -366,7 +380,6 @@ def isrot(R, check=False, tol=100):
366380
"""
367381
return isinstance(R, np.ndarray) and R.shape == (3, 3) and (not check or base.isR(R, tol=tol))
368382

369-
370383
# ---------------------------------------------------------------------------------------#
371384
def rpy2r(roll, pitch=None, yaw=None, *, unit='rad', order='zyx'):
372385
"""

‎spatialmath/base/transformsNd.py

Copy file name to clipboardExpand all lines: spatialmath/base/transformsNd.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,6 @@ def Ab2M(A, b):
288288

289289
# ======================= predicates
290290

291-
292291
def isR(R, tol=100):
293292
r"""
294293
Test if matrix belongs to SO(n)
@@ -312,8 +311,9 @@ def isR(R, tol=100):
312311
313312
:seealso: isrot2, isrot
314313
"""
315-
return np.linalg.norm(R@R.T - np.eye(R.shape[0])) < tol * _eps \
316-
and np.linalg.det(R@R.T) > 0
314+
pre = R@R.T
315+
return np.linalg.norm(pre - np.eye(R.shape[0])) < tol * _eps \
316+
and np.linalg.det(pre) > 0
317317

318318

319319
def isskew(S, tol=10):

‎spatialmath/base/vectors.py

Copy file name to clipboardExpand all lines: spatialmath/base/vectors.py
+12Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import numpy as np
1616
from spatialmath.base import getvector
1717

18+
from spatialmath.base.sm_numba import numba_njit, numba_overload
19+
1820
try: # pragma: no cover
1921
# print('Using SymPy')
2022
import sympy
@@ -26,6 +28,7 @@
2628

2729
_eps = np.finfo(np.float64).eps
2830

31+
import numba
2932

3033
def colvec(v):
3134
"""
@@ -138,6 +141,15 @@ def norm(v):
138141
else:
139142
return math.sqrt(sum)
140143

144+
@numba_overload(norm)
145+
def norm_fast(v):
146+
def n(v):
147+
sum = 0
148+
for x in v:
149+
sum += x * x
150+
return np.sqrt(sum)
151+
return n
152+
141153
def normsq(v):
142154
"""
143155
Squared norm of vector

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.