Commit efe322b

FIX: Update Numba Lecture to Address Deprecation of @jit (#296)

Authored by HumphreyYang and mmcky.

* update a section on type inference
* update lecture to avoid literal box warning
* check the type of the function
* Update lectures/numba.md
* reduce redundancy
* further simplifies descriptions
* fix typos

Co-authored-by: mmcky <mmcky@users.noreply.github.com>

1 parent 995c490 · commit efe322b

lectures/numba.md  (1 file changed: 118 additions, 60 deletions)
@@ -3,8 +3,10 @@ jupytext:
   text_representation:
     extension: .md
     format_name: myst
+    format_version: 0.13
+    jupytext_version: 1.14.4
 kernelspec:
-  display_name: Python 3
+  display_name: Python 3 (ipykernel)
   language: python
   name: python3
 ---
@@ -26,10 +28,9 @@ kernelspec:
 
 In addition to what's in Anaconda, this lecture will need the following libraries:
 
-```{code-cell} ipython
----
-tags: [hide-output]
----
+```{code-cell} ipython3
+:tags: [hide-output]
+
 !pip install quantecon
 ```
 
@@ -38,7 +39,7 @@ versions are a {doc}`common source of errors <troubleshooting>`.
 
 Let's start with some imports:
 
-```{code-cell} ipython
+```{code-cell} ipython3
 %matplotlib inline
 import numpy as np
 import quantecon as qe
@@ -98,13 +99,13 @@ $$
 
 In what follows we set
 
-```{code-cell} python3
+```{code-cell} ipython3
 α = 4.0
 ```
 
 Here's the plot of a typical trajectory, starting from $x_0 = 0.1$, with $t$ on the x-axis
 
-```{code-cell} python3
+```{code-cell} ipython3
 def qm(x0, n):
     x = np.empty(n+1)
     x[0] = x0
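For reference, the function being retargeted in this hunk is just the quadratic (logistic) map. A standalone plain-NumPy sketch of it (no Numba required) behaves as follows:

```python
import numpy as np

def qm(x0, n):
    """Generate a trajectory of the quadratic map x_{t+1} = 4 x_t (1 - x_t)."""
    x = np.empty(n + 1)
    x[0] = x0
    for t in range(n):
        x[t+1] = 4 * x[t] * (1 - x[t])
    return x

# starting from x0 = 0.1 as in the lecture, the trajectory stays in [0, 1]
x = qm(0.1, 250)
```

Since the loop body uses only scalar arithmetic and a preallocated array, it is exactly the kind of function Numba's type inference handles well.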
@@ -122,10 +123,10 @@ plt.show()
 
 To speed the function `qm` up using Numba, our first step is
 
-```{code-cell} python3
-from numba import jit
+```{code-cell} ipython3
+from numba import njit
 
-qm_numba = jit(qm)
+qm_numba = njit(qm)
 ```
 
 The function `qm_numba` is a version of `qm` that is "targeted" for
@@ -135,7 +136,7 @@ We will explain what this means momentarily.
 
 Let's time and compare identical function calls across these two versions, starting with the original function `qm`:
 
-```{code-cell} python3
+```{code-cell} ipython3
 n = 10_000_000
 
 qe.tic()
@@ -145,7 +146,7 @@ time1 = qe.toc()
 
 Now let's try qm_numba
 
-```{code-cell} python3
+```{code-cell} ipython3
 qe.tic()
 qm_numba(0.1, int(n))
 time2 = qe.toc()
@@ -156,13 +157,14 @@ This is already a massive speed gain.
 In fact, the next time and all subsequent times it runs even faster as the function has been compiled and is in memory:
 
 (qm_numba_result)=
-```{code-cell} python3
+
+```{code-cell} ipython3
 qe.tic()
 qm_numba(0.1, int(n))
 time3 = qe.toc()
 ```
 
-```{code-cell} python3
+```{code-cell} ipython3
 time1 / time3 # Calculate speed gain
 ```
 
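The `qe.tic()`/`qe.toc()` pairs in these cells are simple wall-clock timers. A stdlib-only stand-in (the helper name `tic_toc` is hypothetical, not part of quantecon) makes the measurement pattern explicit:

```python
import time

def tic_toc(func, *args):
    # time a single call, like the qe.tic()/qe.toc() pairs above
    t0 = time.perf_counter()
    result = func(*args)
    elapsed = time.perf_counter() - t0
    return result, elapsed

_, t_first = tic_toc(sum, range(100_000))
_, t_second = tic_toc(sum, range(100_000))
speed_gain = t_first / t_second  # analogous to time1 / time3 above
```

With a JIT-compiled function the first call includes compilation time, which is why the lecture times a second call before computing the speed gain.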
@@ -194,12 +196,12 @@ Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 2
 
 The compiled code is then cached and recycled as required.
 
-## Decorators and "nopython" Mode
+## Decorator Notation
 
 In the code above we created a JIT compiled version of `qm` via the call
 
-```{code-cell} python3
-qm_numba = jit(qm)
+```{code-cell} ipython3
+qm_numba = njit(qm)
 ```
 
 In practice this would typically be done using an alternative *decorator* syntax.
@@ -208,14 +210,12 @@ In practice this would typically be done using an alternative *decorator* syntax.
 
 Let's see how this is done.
 
-### Decorator Notation
-
-To target a function for JIT compilation we can put `@jit` before the function definition.
+To target a function for JIT compilation we can put `@njit` before the function definition.
 
 Here's what this looks like for `qm`
 
-```{code-cell} python3
-@jit
+```{code-cell} ipython3
+@njit
 def qm(x0, n):
     x = np.empty(n+1)
     x[0] = x0
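The decorator syntax introduced in this hunk is ordinary Python. A toy decorator (standing in for `numba.njit`, which requires Numba installed) shows the equivalence the lecture states:

```python
def double_result(func):
    # a toy decorator standing in for numba.njit, which likewise
    # takes a function and returns a wrapped version of it
    def wrapper(*args, **kwargs):
        return 2 * func(*args, **kwargs)
    return wrapper

@double_result
def add(x, y):
    return x + y

# `@double_result` above is exactly equivalent to writing:
def add_plain(x, y):
    return x + y

add_plain = double_result(add_plain)

print(add(2, 3), add_plain(2, 3))  # both print 10
```

This is why `@njit` above a definition and `qm = njit(qm)` after it produce the same object.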
@@ -224,15 +224,21 @@ def qm(x0, n):
     return x
 ```
 
-This is equivalent to `qm = jit(qm)`.
+This is equivalent to `qm = njit(qm)`.
 
 The following now uses the jitted version:
 
-```{code-cell} python3
-qm(0.1, 10)
+```{code-cell} ipython3
+%%time
+
+qm(0.1, 100_000)
 ```
 
-### Type Inference and "nopython" Mode
+Numba provides [several arguments](https://numba.readthedocs.io/en/stable/user/performance-tips.html) for its decorators that accelerate computation and cache functions.
+
+In the [following lecture on parallelization](parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization.
+
+## Type Inference
 
 Clearly type inference is a key part of JIT compilation.
 
@@ -246,29 +252,83 @@ This allows it to generate native machine code, without having to call the Python runtime.
 
 In such a setting, Numba will be on par with machine code from low-level languages.
 
-When Numba cannot infer all type information, some Python objects are given generic object status and execution falls back to the Python runtime.
+When Numba cannot infer all type information, it will raise an error.
 
-When this happens, Numba provides only minor speed gains or none at all.
+For example, in the case below, Numba is unable to determine the type of the function `mean` when compiling the function `bootstrap`
 
-We generally prefer to force an error when this occurs, so we know effective
-compilation is failing.
+```{code-cell} ipython3
+@njit
+def bootstrap(data, statistics, n):
+    bootstrap_stat = np.empty(n)
+    n_data = len(data)
+    for i in range(n):
+        resample = np.random.choice(data, size=n_data, replace=True)
+        bootstrap_stat[i] = statistics(resample)
+    return bootstrap_stat
 
-This is done by using either `@jit(nopython=True)` or, equivalently, `@njit` instead of `@jit`.
+def mean(data):
+    return np.mean(data)
 
-For example,
+data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2])
+n_resamples = 10
 
-```{code-cell} python3
-from numba import njit
+print('Type of function:', type(mean))
+
+# Error: Numba cannot infer the type of the plain-Python function `mean`
+try:
+    bootstrap(data, mean, n_resamples)
+except Exception as e:
+    print(e)
+```
 
+But Numba recognizes JIT-compiled functions
+
+```{code-cell} ipython3
 @njit
-def qm(x0, n):
-    x = np.empty(n+1)
-    x[0] = x0
-    for t in range(n):
-        x[t+1] = 4 * x[t] * (1 - x[t])
-    return x
+def mean(data):
+    return np.mean(data)
+
+print('Type of function:', type(mean))
+
+%time bootstrap(data, mean, n_resamples)
+```
+
+We can check the signature of the JIT-compiled function
+
+```{code-cell} ipython3
+bootstrap.signatures
+```
+
+The function `bootstrap` takes one `float64` floating point array, one function called `mean` and an `int64` integer.
+
+Now let's see what happens when we change the inputs.
+
+Running it again with a larger integer for `n` and a different set of data does not change the signature of the function.
+
+```{code-cell} ipython3
+data = np.array([4.1, 1.1, 2.3, 1.9, 0.1, 2.8, 1.2])
+%time bootstrap(data, mean, 100)
+bootstrap.signatures
+```
+
+As expected, the second run is much faster.
+
+Let's try to change the data again and use an integer array as data
+
+```{code-cell} ipython3
+data = np.array([1, 2, 3, 4, 5], dtype=np.int64)
+%time bootstrap(data, mean, 100)
+bootstrap.signatures
+```
+
+Note that a second signature is added.
+
+It also takes longer to run, suggesting that Numba recompiles this function as the type changes.
+
+Overall, type inference helps Numba to achieve its performance, but it also limits what Numba supports and sometimes requires careful type checks.
+
+You can refer to the list of supported Python and NumPy features [here](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html).
+
 ## Compiling Classes
 
 As mentioned above, at present Numba can only compile a subset of Python.
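The `bootstrap` logic added in the hunk above can be checked without Numba. Here is a plain-NumPy sketch (the seeded `default_rng` is an assumption added for reproducibility; nothing is compiled, so the plain-Python statistic poses no typing problem):

```python
import numpy as np

def bootstrap(data, statistic, n):
    # draw n bootstrap resamples of `data` (with replacement)
    # and apply `statistic` to each resample
    rng = np.random.default_rng(0)
    out = np.empty(n)
    for i in range(n):
        resample = rng.choice(data, size=len(data), replace=True)
        out[i] = statistic(resample)
    return out

data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2])
stats = bootstrap(data, np.mean, 10)  # 10 bootstrap estimates of the mean
```

Under Numba, the same loop compiles only once the `statistic` argument is itself a typed, JIT-compiled function, which is the point the new lecture section makes.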
@@ -285,7 +345,7 @@ created in {doc}`this lecture <python_oop>`.
 
 To compile this class we use the `@jitclass` decorator:
 
-```{code-cell} python3
+```{code-cell} ipython3
 from numba import float64
 from numba.experimental import jitclass
 ```
@@ -294,11 +354,11 @@ Notice that we also imported something called `float64`.
 
 This is a data type representing standard floating point numbers.
 
-We are importing it here because Numba needs a bit of extra help with types when it trys to deal with classes.
+We are importing it here because Numba needs a bit of extra help with types when it tries to deal with classes.
 
 Here's our code:
 
-```{code-cell} python3
+```{code-cell} ipython3
 solow_data = [
     ('n', float64),
     ('s', float64),
@@ -361,7 +421,7 @@ After that, targeting the class for JIT compilation only requires adding
 
 When we call the methods in the class, the methods are compiled just like functions.
 
-```{code-cell} python3
+```{code-cell} ipython3
 s1 = Solow()
 s2 = Solow(k=8.0)
 
@@ -444,25 +504,25 @@ For larger ones, or for routines using external libraries, it can easily fail.
 
 Hence, it's prudent when using Numba to focus on speeding up small, time-critical snippets of code.
 
-This will give you much better performance than blanketing your Python programs with `@jit` statements.
+This will give you much better performance than blanketing your Python programs with `@njit` statements.
 
 ### A Gotcha: Global Variables
 
 Here's another thing to be careful about when using Numba.
 
 Consider the following example
 
-```{code-cell} python3
+```{code-cell} ipython3
 a = 1
 
-@jit
+@njit
 def add_a(x):
     return a + x
 
 print(add_a(10))
 ```
 
-```{code-cell} python3
+```{code-cell} ipython3
 a = 2
 
 print(add_a(10))
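The gotcha in this hunk is that `@njit` freezes the value of a global at compile time, so rebinding `a = 2` afterwards has no effect on the compiled function. The usual workaround, jitted or not, is to pass the value in as an argument; a plain-Python sketch:

```python
def add_a(x, a):
    # `a` is now an explicit argument, so no value can be frozen
    # into compiled code the way a global would be
    return a + x

print(add_a(10, 1))  # 11
print(add_a(10, 2))  # 12 -- the new value of `a` is always seen
```

The same signature works under `@njit`, and Numba will specialize on the argument types rather than capturing a stale global.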
@@ -492,7 +552,7 @@ Compare speed with and without Numba when the sample size is large.
 
 Here is one solution:
 
-```{code-cell} python3
+```{code-cell} ipython3
 from random import uniform
 
 @njit
@@ -581,13 +641,13 @@ We let
 - 0 represent "low"
 - 1 represent "high"
 
-```{code-cell} python3
+```{code-cell} ipython3
 p, q = 0.1, 0.2 # Prob of leaving low and high state respectively
 ```
 
 Here's a pure Python version of the function
 
-```{code-cell} python3
+```{code-cell} ipython3
 def compute_series(n):
     x = np.empty(n, dtype=np.int_)
     x[0] = 1 # Start in state 1
@@ -604,7 +664,7 @@ def compute_series(n):
 
 Let's run this code and check that the fraction of time spent in the low
 state is about 0.666
 
-```{code-cell} python3
+```{code-cell} ipython3
 n = 1_000_000
 x = compute_series(n)
 print(np.mean(x == 0)) # Fraction of time x is in state 0
@@ -614,30 +674,28 @@ This is (approximately) the right output.
 
 Now let's time it:
 
-```{code-cell} python3
+```{code-cell} ipython3
 qe.tic()
 compute_series(n)
 qe.toc()
 ```
 
 Next let's implement a Numba version, which is easy
 
-```{code-cell} python3
-from numba import jit
-
-compute_series_numba = jit(compute_series)
+```{code-cell} ipython3
+compute_series_numba = njit(compute_series)
 ```
 
 Let's check we still get the right numbers
 
-```{code-cell} python3
+```{code-cell} ipython3
 x = compute_series_numba(n)
 print(np.mean(x == 0))
 ```
 
 Let's see the time
 
-```{code-cell} python3
+```{code-cell} ipython3
 qe.tic()
 compute_series_numba(n)
 qe.toc()
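As a sanity check on the function being jitted in these hunks, here is a self-contained NumPy version of the two-state chain (the seeded `default_rng` is an assumption for reproducibility). With p = 0.1 and q = 0.2, the stationary probability of the low state is q/(p+q) = 2/3 ≈ 0.666, matching the lecture's expected output:

```python
import numpy as np

def compute_series(n, p=0.1, q=0.2, seed=1234):
    # Two-state Markov chain: leave state 0 w.p. p, leave state 1 w.p. q.
    # Stationary probability of state 0 is q / (p + q).
    rng = np.random.default_rng(seed)
    U = rng.random(n)
    x = np.empty(n, dtype=np.int_)
    x[0] = 1  # start in state 1, as in the lecture
    for t in range(1, n):
        if x[t-1] == 0:
            x[t] = 1 if U[t] < p else 0
        else:
            x[t] = 0 if U[t] < q else 1
    return x

x = compute_series(200_000)
frac_low = np.mean(x == 0)  # should be close to 0.2 / 0.3
```

The scalar loop over a preallocated integer array is again an ideal target for `njit`, which is why the Numba version in the diff needs no code changes.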
