Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit aec82dc

Browse filesBrowse files
authored
Merge pull request #145 from kushalkolar/better-dtype-conversions
better dtype handling for GPU supported types
2 parents fd417d2 + b5b7da0 commit aec82dc
Copy full SHA for aec82dc

File tree

3 files changed

+9
-6
lines changed
Filter options

3 files changed

+9
-6
lines changed

‎fastplotlib/graphics/features/_base.py

Copy file name to clipboardExpand all lines: fastplotlib/graphics/features/_base.py
+4-3Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121

2222
def to_gpu_supported_dtype(array):
23+
"""
24+
If ``array`` is a numpy array, converts it to a supported type. GPUs don't support 64 bit dtypes.
25+
"""
2326
if isinstance(array, np.ndarray):
2427
if array.dtype not in supported_dtypes:
2528
if np.issubdtype(array.dtype, np.integer):
@@ -69,10 +72,8 @@ def __init__(self, parent, data: Any, collection_index: int = None):
6972
7073
"""
7174
self._parent = parent
72-
if isinstance(data, np.ndarray):
73-
data = to_gpu_supported_dtype(data)
7475

75-
self._data = data
76+
self._data = to_gpu_supported_dtype(data)
7677

7778
self._collection_index = collection_index
7879
self._event_handlers = list()

‎fastplotlib/graphics/features/_colors.py

Copy file name to clipboardExpand all lines: fastplotlib/graphics/features/_colors.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def __setitem__(self, key, value):
206206

207207
n_colors = len(range(key.start, key.stop, key.step))
208208

209-
colors = make_colors(n_colors, cmap=value)
209+
colors = make_colors(n_colors, cmap=value).astype(self._data.dtype)
210210
super(CmapFeature, self).__setitem__(key, colors)
211211

212212

‎fastplotlib/graphics/features/_data.py

Copy file name to clipboardExpand all lines: fastplotlib/graphics/features/_data.py
+4-2Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,23 @@ def __getitem__(self, item):
2525
def _fix_data(self, data, parent):
2626
graphic_type = parent.__class__.__name__
2727

28+
data = to_gpu_supported_dtype(data)
29+
2830
if data.ndim == 1:
2931
# for scatter if we receive just 3 points in a 1d array, treat it as just a single datapoint
3032
# this is different from fix_data for LineGraphic since there we assume that a 1d array
3133
# is just y-values
3234
if graphic_type == "ScatterGraphic":
3335
data = np.array([data])
3436
elif graphic_type == "LineGraphic":
35-
data = np.dstack([np.arange(data.size), data])[0].astype(np.float32)
37+
data = np.dstack([np.arange(data.size, dtype=data.dtype), data])[0]
3638

3739
if data.shape[1] != 3:
3840
if data.shape[1] != 2:
3941
raise ValueError(f"Must pass 1D, 2D or 3D data to {graphic_type}")
4042

4143
# zeros for z
42-
zs = np.zeros(data.shape[0], dtype=np.float32)
44+
zs = np.zeros(data.shape[0], dtype=data.dtype)
4345

4446
data = np.dstack([data[:, 0], data[:, 1], zs])[0]
4547

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.