Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit bc03a94

Browse filesBrowse files
committed
Support imshow(<float16 array>).
1 parent 0ca9d9d commit bc03a94
Copy full SHA for bc03a94

File tree

Expand file treeCollapse file tree

2 files changed

+13
-8
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+13
-8
lines changed

‎lib/matplotlib/image.py

Copy file name to clipboardExpand all lines: lib/matplotlib/image.py
+6-8Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -411,14 +411,12 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
411411
# all masked, so values don't matter
412412
a_min, a_max = np.int32(0), np.int32(1)
413413
if inp_dtype.kind == 'f':
414-
scaled_dtype = A.dtype
415-
# Cast to float64
416-
if A.dtype not in (np.float32, np.float16):
417-
if A.dtype != np.float64:
418-
_api.warn_external(
419-
f"Casting input data from '{A.dtype}' to "
420-
f"'float64' for imshow")
421-
scaled_dtype = np.float64
414+
scaled_dtype = np.dtype(
415+
np.float64 if A.dtype.itemsize > 4 else np.float32)
416+
if scaled_dtype.itemsize < A.dtype.itemsize:
417+
_api.warn_external(
418+
f"Casting input data from {A.dtype} to "
419+
f"{scaled_dtype} for imshow")
422420
else:
423421
# probably an integer of some type.
424422
da = a_max.astype(np.float64) - a_min.astype(np.float64)

‎lib/matplotlib/tests/test_image.py

Copy file name to clipboardExpand all lines: lib/matplotlib/tests/test_image.py
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,13 @@ def test_empty_imshow(make_norm):
993993
im.make_image(fig._cachedRenderer)
994994

995995

996+
def test_imshow_float16():
997+
fig, ax = plt.subplots()
998+
ax.imshow(np.zeros((3, 3), dtype=np.float16))
999+
# Ensure that drawing doesn't cause crash.
1000+
fig.canvas.draw()
1001+
1002+
9961003
def test_imshow_float128():
9971004
fig, ax = plt.subplots()
9981005
ax.imshow(np.zeros((3, 3), dtype=np.longdouble))

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.