Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit be0b0ee

Browse filesBrowse files
committed
numba decor
1 parent 72d97b3 commit be0b0ee
Copy full SHA for be0b0ee

File tree

Expand file treeCollapse file tree

6 files changed

+209
-182
lines changed
Filter options
Expand file treeCollapse file tree

6 files changed

+209
-182
lines changed

‎spatialmath/base/argcheck.py

Copy file name to clipboardExpand all lines: spatialmath/base/argcheck.py
+17-90Lines changed: 17 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -15,99 +15,11 @@
1515
import numpy as np
1616
from spatialmath.base import symbolic as sym
1717

18+
from spatialmath.base.sm_numba import numba_njit, numba_overload
19+
1820
# valid scalar types
1921
_scalartypes = (int, np.integer, float, np.floating) + sym.symtype
2022

21-
22-
# def getunit(v, unit='rad'):
23-
# """
24-
# Convert value according to angular units
25-
26-
# :param v: the value in radians or degrees
27-
# :type v: array_like(m) or ndarray(m)
28-
# :param unit: the angular unit, "rad" or "deg"
29-
# :type unit: str
30-
# :return: the converted value in radians
31-
# :rtype: list(m) or ndarray(m)
32-
# :raises ValueError: argument is not a valid angular unit
33-
34-
# .. runblock:: pycon
35-
36-
# >>> from spatialmath.base import getunit
37-
# >>> import numpy as np
38-
# >>> getunit(1.5, 'rad')
39-
# >>> getunit(90, 'deg')
40-
# >>> getunit([90, 180], 'deg')
41-
# >>> getunit(np.r_[0.5, 1], 'rad')
42-
# >>> getunit(np.r_[90, 180], 'deg')
43-
# """
44-
# if unit == "rad":
45-
# return v
46-
# elif unit == "deg":
47-
# if isinstance(v, np.ndarray) or np.isscalar(v):
48-
# return v * math.pi / 180
49-
# else:
50-
# return [x * math.pi / 180 for x in v]
51-
# else:
52-
# raise ValueError("invalid angular units")
53-
54-
# @numba.generated_jit(nopython=True)
55-
# def getunit(v, unit='rad'):
56-
# """
57-
# Convert value according to angular units
58-
59-
# :param v: the value in radians or degrees
60-
# :type v: array_like(m) or ndarray(m)
61-
# :param unit: the angular unit, "rad" or "deg"
62-
# :type unit: str
63-
# :return: the converted value in radians
64-
# :rtype: list(m) or ndarray(m)
65-
# :raises ValueError: argument is not a valid angular unit
66-
67-
# .. runblock:: pycon
68-
69-
# >>> from spatialmath.base import getunit
70-
# >>> import numpy as np
71-
# >>> getunit(1.5, 'rad')
72-
# >>> getunit(90, 'deg')
73-
# >>> getunit([90, 180], 'deg')
74-
# >>> getunit(np.r_[0.5, 1], 'rad')
75-
# >>> getunit(np.r_[90, 180], 'deg')
76-
# """
77-
78-
# @numba.njit
79-
# def conv(v, unit):
80-
# if unit == "rad":
81-
# return v
82-
# elif unit == "deg":
83-
# return v * math.pi / 180
84-
85-
# @numba.njit
86-
# def conv_l(v, unit):
87-
# if unit == "rad":
88-
# return v
89-
# elif unit == "deg":
90-
# return [y * np.pi / 180 for y in x]
91-
92-
# if isinstance(v, np.ndarray) or isinstance(v, numba.types.Float):
93-
# return lambda a, b: conv(a, b)
94-
95-
# # elif isinstance(v, list):
96-
# # return lambda a, b: conv_l(a, b)
97-
98-
# # if unit == "rad":
99-
# # # return v
100-
# # return lambda v: v
101-
# # elif unit == "deg":
102-
# # if isinstance(v, np.ndarray) or np.isscalar(v):
103-
# # # return v * math.pi / 180
104-
# # return lambda x: x * np.pi / 180
105-
# # else:
106-
# # # return [x * math.pi / 180 for x in v]
107-
# # return lambda x: [y * np.pi / 180 for y in x]
108-
# # else:
109-
# # raise ValueError("invalid angular units")
110-
11123
def isscalar(x):
11224
"""
11325
Test if argument is a real scalar
@@ -537,6 +449,21 @@ def getunit(v, unit='rad'):
537449
else:
538450
raise ValueError("invalid angular units")
539451

452+
@numba_overload(getunit)
453+
def getunit_numba(v, unit='rad'):
454+
def n(v, unit='rad'):
455+
if unit == "rad":
456+
return v
457+
elif unit == "deg":
458+
return v * np.pi / 180.0
459+
else:
460+
return v
461+
return n
462+
463+
@numba_njit
464+
def getunit_fast(v, unit='rad'):
465+
return getunit(v, unit='rad')
466+
540467

541468
def isnumberlist(x):
542469
"""

‎spatialmath/base/sm_numba.py

Copy file name to clipboard
+15-8Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
# _numba = True
33
from functools import wraps
4+
import os
45

56
try:
67
global _numba
@@ -10,19 +11,23 @@
1011
# global _numba
1112
_numba = False
1213

14+
1315
# Equivalent to @numba.njit
1416
def numba_njit(func):
1517
if _numba:
16-
junc = numba.njit(func)
18+
return (numba.njit(cache=True))(func)
19+
else:
20+
return func
1721

18-
@wraps(func)
19-
def wrapper(*args, **kwargs):
20-
if _numba:
21-
return junc(*args, **kwargs)
22-
else:
23-
return func(*args, **kwargs)
22+
# @wraps(func)
23+
# def wrapper(*args, **kwargs):
24+
# if _numba:
25+
# return junc(*args, **kwargs)
26+
# else:
27+
# return func(*args, **kwargs)
28+
29+
# return wrapper
2430

25-
return wrapper
2631

2732
# Equivalent to @numba.extending.overload(method)
2833
def numba_overload(method):
@@ -34,10 +39,12 @@ def decor(func):
3439

3540
return decor
3641

42+
3743
def use_numba(use_numba):
3844
global _numba
3945
_numba = use_numba
4046

47+
4148
def using_numba():
4249
global _numba
4350
return _numba

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.