Commit 3851a90

jstac and mmcky authored
misc (#337)
Co-authored-by: Matt McKay <mmcky@users.noreply.github.com>
1 parent c66c93e commit 3851a90


‎lectures/numba.md‎

Lines changed: 33 additions & 64 deletions
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.4
+    jupytext_version: 1.16.7
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -118,9 +118,9 @@ plt.show()
 To speed the function `qm` up using Numba, our first step is

 ```{code-cell} ipython3
-from numba import njit
+from numba import jit

-qm_numba = njit(qm)
+qm_numba = jit(qm)
 ```

 The function `qm_numba` is a version of `qm` that is "targeted" for
@@ -146,7 +146,7 @@ qm_numba(0.1, int(n))
 time2 = qe.toc()
 ```

-This is already a massive speed gain.
+This is already a very large speed gain.

 In fact, the next time and all subsequent times it runs even faster as the function has been compiled and is in memory:

@@ -162,7 +162,7 @@ time3 = qe.toc()
 time1 / time3 # Calculate speed gain
 ```

-This kind of speed gain is huge relative to how simple and clear the implementation is.
+This kind of speed gain is impressive relative to how simple and clear the modification is.

 ### How and When it Works

@@ -177,10 +177,10 @@ The basic idea is this:
 * Python is very flexible and hence we could call the function qm with many
 types.
 * e.g., `x0` could be a NumPy array or a list, `n` could be an integer or a float, etc.
-* This makes it hard to *pre*-compile the function.
-* However, when we do actually call the function, say by executing `qm(0.5, 10)`,
+* This makes it hard to *pre*-compile the function (i.e., compile before runtime).
+* However, when we do actually call the function, say by running `qm(0.5, 10)`,
 the types of `x0` and `n` become clear.
-* Moreover, the types of other variables in `qm` can be inferred once the input is known.
+* Moreover, the types of other variables in `qm` can be inferred once the input types are known.
 * So the strategy of Numba and other JIT compilers is to wait until this
 moment, and *then* compile the function.
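
The wait-then-compile strategy in the list above can be mimicked with a toy, pure-Python model. This is only a sketch of the caching idea, not how Numba is implemented: `toy_jit`, `compile_log` and `scale` are hypothetical names made up for illustration, and no machine code is generated.

```python
# Toy model of "compile on first call, once per combination of argument types".
# Nothing here is Numba's API; toy_jit, compile_log and scale are made-up names.

compile_log = []  # records each simulated compilation


def toy_jit(func):
    cache = {}  # maps a tuple of argument types to a "compiled" version

    def dispatcher(*args):
        signature = tuple(type(a) for a in args)
        if signature not in cache:
            # A real JIT compiler would generate specialized machine code here.
            compile_log.append((func.__name__, signature))
            cache[signature] = func
        return cache[signature](*args)

    return dispatcher


def scale(x, n):
    return x * n


scale_jit = toy_jit(scale)

scale_jit(0.5, 10)  # first (float, int) call: triggers a simulated compilation
scale_jit(0.9, 20)  # same types: reuses the cached version
scale_jit(2, 3)     # new types (int, int): triggers a second simulated compilation

print(len(compile_log))  # number of simulated compilations
```

Only the first call with each new combination of argument types pays the simulated compilation cost, which is the behaviour the list describes.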
@@ -190,26 +190,28 @@ Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 2

 The compiled code is then cached and recycled as required.

+This is why, in the code above, `time3` is smaller than `time2`.
+
 ## Decorator Notation

 In the code above we created a JIT compiled version of `qm` via the call

 ```{code-cell} ipython3
-qm_numba = njit(qm)
+qm_numba = jit(qm)
 ```

 In practice this would typically be done using an alternative *decorator* syntax.

-(We will explain all about decorators in a {doc}`later lecture <python_advanced_features>` but you can skip the details at this stage.)
+(We discuss decorators in a {doc}`separate lecture <python_advanced_features>` but you can skip the details at this stage.)

 Let's see how this is done.

-To target a function for JIT compilation we can put `@njit` before the function definition.
+To target a function for JIT compilation we can put `@jit` before the function definition.

 Here's what this looks like for `qm`

 ```{code-cell} ipython3
-@njit
+@jit
 def qm(x0, n):
     x = np.empty(n+1)
     x[0] = x0
@@ -218,7 +220,7 @@ def qm(x0, n):
     return x
 ```

-This is equivalent to `qm = njit(qm)`.
+This is equivalent to adding `qm = jit(qm)` after the function definition.
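
The equivalence between decorator syntax and explicit rebinding is a general Python feature, not something specific to Numba. A minimal sketch, using a hypothetical call-counting decorator `count_calls` in place of `jit` so that it runs without Numba installed:

```python
import functools


def count_calls(func):
    """Hypothetical stand-in for a decorator such as jit: counts calls to func."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)
    wrapper.calls = 0
    return wrapper


# Decorator syntax
@count_calls
def square(x):
    return x * x


# Explicit rebinding after the function definition
def cube(x):
    return x ** 3

cube = count_calls(cube)

print(square(3), cube(3))        # prints: 9 27
print(square.calls, cube.calls)  # prints: 1 1
```

Both names end up bound to a wrapped function, so the two forms behave the same way.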

 The following now uses the jitted version:

@@ -228,13 +230,19 @@ The following now uses the jitted version:
 qm(0.1, 100_000)
 ```

-Numba provides several arguments for decorators to accelerate computation and cache functions [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html).
+```{code-cell} ipython3
+%%time
+
+qm(0.1, 100_000)
+```
+
+Numba also provides several arguments for decorators to accelerate computation and cache functions -- see [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html).

 In the [following lecture on parallelization](parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization.

 ## Type Inference

-Clearly type inference is a key part of JIT compilation.
+Successful type inference is a key part of JIT compilation.

 As you can imagine, inferring types is easier for simple Python objects (e.g., simple scalar data types such as floats and integers).

@@ -248,10 +256,10 @@ In such a setting, Numba will be on par with machine code from low-level languag

 When Numba cannot infer all type information, it will raise an error.

-For example, in the case below, Numba is unable to determine the type of function `mean` when compiling the function `bootstrap`
+For example, in the (artificial) setting below, Numba is unable to determine the type of function `mean` when compiling the function `bootstrap`

 ```{code-cell} ipython3
-@njit
+@jit
 def bootstrap(data, statistics, n):
     bootstrap_stat = np.empty(n)
     n = len(data)
@@ -260,69 +268,30 @@ def bootstrap(data, statistics, n):
         bootstrap_stat[i] = statistics(resample)
     return bootstrap_stat

+# No decorator here.
 def mean(data):
     return np.mean(data)

-data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2])
+data = np.array((2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2))
 n_resamples = 10

-print('Type of function:', type(mean))
-
-#Error
+# This code throws an error
 try:
     bootstrap(data, mean, n_resamples)
 except Exception as e:
     print(e)
 ```

-But Numba recognizes JIT-compiled functions
+We can fix this error easily in this case by compiling `mean`.

 ```{code-cell} ipython3
-@njit
+@jit
 def mean(data):
     return np.mean(data)

-print('Type of function:', type(mean))
-
 %time bootstrap(data, mean, n_resamples)
 ```

-We can check the signature of the JIT-compiled function
-
-```{code-cell} ipython3
-bootstrap.signatures
-```
-
-The function `bootstrap` takes one `float64` floating point array, one function called `mean` and an `int64` integer.
-
-Now let's see what happens when we change the inputs.
-
-Running it again with a larger integer for `n` and a different set of data does not change the signature of the function.
-
-```{code-cell} ipython3
-data = np.array([4.1, 1.1, 2.3, 1.9, 0.1, 2.8, 1.2])
-%time bootstrap(data, mean, 100)
-bootstrap.signatures
-```
-
-As expected, the second run is much faster.
-
-Let's try to change the data again and use an integer array as data
-
-```{code-cell} ipython3
-data = np.array([1, 2, 3, 4, 5], dtype=np.int64)
-%time bootstrap(data, mean, 100)
-bootstrap.signatures
-```
-
-Note that a second signature is added.
-
-It also takes longer to run, suggesting that Numba recompiles this function as the type changes.
-
-Overall, type inference helps Numba to achieve its performance, but it also limits what Numba supports and sometimes requires careful type checks.
-
-You can refer to the list of supported Python and Numpy features [here](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html).
-
 ## Compiling Classes

 As mentioned above, at present Numba can only compile a subset of Python.
@@ -509,7 +478,7 @@ Consider the following example
 ```{code-cell} ipython3
 a = 1

-@njit
+@jit
 def add_a(x):
     return a + x

@@ -549,7 +518,7 @@ Here is one solution:
 ```{code-cell} ipython3
 from random import uniform

-@njit
+@jit
 def calculate_pi(n=1_000_000):
     count = 0
     for i in range(n):
@@ -677,7 +646,7 @@ qe.toc()
 Next let's implement a Numba version, which is easy

 ```{code-cell} ipython3
-compute_series_numba = njit(compute_series)
+compute_series_numba = jit(compute_series)
 ```

 Let's check we still get the right numbers
