Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit fd417d2

Browse filesBrowse files
authored
Merge pull request #143 from kushalkolar/large-images
HeatmapGraphic, supports dims larger than 8192
2 parents 70b4908 + e464925 commit fd417d2
Copy full SHA for fd417d2

File tree

7 files changed

+268
-19
lines changed
Filter options

7 files changed

+268
-19
lines changed

‎fastplotlib/graphics/__init__.py

Copy file name to clipboardExpand all lines: fastplotlib/graphics/__init__.py
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from .histogram import HistogramGraphic
22
from .line import LineGraphic
33
from .scatter import ScatterGraphic
4-
from .image import ImageGraphic
5-
from .heatmap import HeatmapGraphic
4+
from .image import ImageGraphic, HeatmapGraphic
5+
# from .heatmap import HeatmapGraphic
66
from .text import TextGraphic
77
from .line_collection import LineCollection, LineStack
88

+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature
2-
from ._data import PointsDataFeature, ImageDataFeature
1+
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature, HeatmapCmapFeature
2+
from ._data import PointsDataFeature, ImageDataFeature, HeatmapDataFeature
33
from ._present import PresentFeature
44
from ._thickness import ThicknessFeature
55
from ._base import GraphicFeature, GraphicFeatureIndexable

‎fastplotlib/graphics/features/_base.py

Copy file name to clipboardExpand all lines: fastplotlib/graphics/features/_base.py
+29-1Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,33 @@
77
from pygfx import Buffer
88

99

10+
supported_dtypes = [
11+
np.uint8,
12+
np.uint16,
13+
np.uint32,
14+
np.int8,
15+
np.int16,
16+
np.int32,
17+
np.float16,
18+
np.float32
19+
]
20+
21+
22+
def to_gpu_supported_dtype(array):
23+
if isinstance(array, np.ndarray):
24+
if array.dtype not in supported_dtypes:
25+
if np.issubdtype(array.dtype, np.integer):
26+
warn(f"converting {array.dtype} array to int32")
27+
return array.astype(np.int32)
28+
elif np.issubdtype(array.dtype, np.floating):
29+
warn(f"converting {array.dtype} array to float32")
30+
return array.astype(np.float32, copy=False)
31+
else:
32+
raise TypeError("Unsupported type, supported array types must be int or float dtypes")
33+
34+
return array
35+
36+
1037
class FeatureEvent:
1138
"""
1239
type: <feature_name>, example: "colors"
@@ -43,7 +70,7 @@ def __init__(self, parent, data: Any, collection_index: int = None):
4370
"""
4471
self._parent = parent
4572
if isinstance(data, np.ndarray):
46-
data = data.astype(np.float32)
73+
data = to_gpu_supported_dtype(data)
4774

4875
self._data = data
4976

@@ -227,3 +254,4 @@ def _update_range_indices(self, key):
227254
self._buffer.update_range(ix, size=1)
228255
else:
229256
raise TypeError("must pass int or slice to update range")
257+

‎fastplotlib/graphics/features/_colors.py

Copy file name to clipboardExpand all lines: fastplotlib/graphics/features/_colors.py
+13Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,16 @@ def _feature_changed(self, key, new_data):
238238
event_data = FeatureEvent(type="cmap", pick_info=pick_info)
239239

240240
self._call_event_handlers(event_data)
241+
242+
243+
class HeatmapCmapFeature(ImageCmapFeature):
244+
"""
245+
Colormap for HeatmapGraphic
246+
"""
247+
248+
def _set(self, cmap_name: str):
249+
self._parent._material.map.texture.data[:] = make_colors(256, cmap_name)
250+
self._parent._material.map.texture.update_range((0, 0, 0), size=(256, 1, 1))
251+
self.name = cmap_name
252+
253+
self._feature_changed(key=None, new_data=self.name)

‎fastplotlib/graphics/features/_data.py

Copy file name to clipboardExpand all lines: fastplotlib/graphics/features/_data.py
+47-10Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,7 @@
33
import numpy as np
44
from pygfx import Buffer, Texture
55

6-
from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent
7-
8-
9-
def to_float32(array):
10-
if isinstance(array, np.ndarray):
11-
return array.astype(np.float32, copy=False)
12-
13-
return array
6+
from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent, to_gpu_supported_dtype
147

158

169
class PointsDataFeature(GraphicFeatureIndexable):
@@ -102,7 +95,7 @@ def __init__(self, parent, data: Any):
10295
"``[x_dim, y_dim]`` or ``[x_dim, y_dim, rgb]``"
10396
)
10497

105-
data = to_float32(data)
98+
data = to_gpu_supported_dtype(data)
10699
super(ImageDataFeature, self).__init__(parent, data)
107100

108101
@property
@@ -114,7 +107,7 @@ def __getitem__(self, item):
114107

115108
def __setitem__(self, key, value):
116109
# make sure float32
117-
value = to_float32(value)
110+
value = to_gpu_supported_dtype(value)
118111

119112
self._buffer.data[key] = value
120113
self._update_range(key)
@@ -145,3 +138,47 @@ def _feature_changed(self, key, new_data):
145138
event_data = FeatureEvent(type="data", pick_info=pick_info)
146139

147140
self._call_event_handlers(event_data)
141+
142+
143+
class HeatmapDataFeature(ImageDataFeature):
144+
@property
145+
def _buffer(self) -> List[Texture]:
146+
return [img.geometry.grid.texture for img in self._parent.world_object.children]
147+
148+
def __getitem__(self, item):
149+
return self._data[item]
150+
151+
def __setitem__(self, key, value):
152+
# make sure supported type, not float64 etc.
153+
value = to_gpu_supported_dtype(value)
154+
155+
self._data[key] = value
156+
self._update_range(key)
157+
158+
# avoid creating dicts constantly if there are no events to handle
159+
if len(self._event_handlers) > 0:
160+
self._feature_changed(key, value)
161+
162+
def _update_range(self, key):
163+
for buffer in self._buffer:
164+
buffer.update_range((0, 0, 0), size=buffer.size)
165+
166+
def _feature_changed(self, key, new_data):
167+
if key is not None:
168+
key = cleanup_slice(key, self._upper_bound)
169+
if isinstance(key, int):
170+
indices = [key]
171+
elif isinstance(key, slice):
172+
indices = range(key.start, key.stop, key.step)
173+
elif key is None:
174+
indices = None
175+
176+
pick_info = {
177+
"index": indices,
178+
"world_object": self._parent.world_object,
179+
"new_data": new_data
180+
}
181+
182+
event_data = FeatureEvent(type="data", pick_info=pick_info)
183+
184+
self._call_event_handlers(event_data)

‎fastplotlib/graphics/image.py

Copy file name to clipboardExpand all lines: fastplotlib/graphics/image.py
+175-1Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from typing import *
2+
from math import ceil
3+
from itertools import product
24

35
import pygfx
6+
from pygfx.utils import unpack_bitfield
47

58
from ._base import Graphic, Interaction, PreviouslyModifiedData
6-
from .features import ImageCmapFeature, ImageDataFeature
9+
from .features import ImageCmapFeature, ImageDataFeature, HeatmapDataFeature, HeatmapCmapFeature
710
from ..utils import quick_min_max
811

912

@@ -119,5 +122,176 @@ def _reset_feature(self, feature: str):
119122
pass
120123

121124

125+
class _ImageTile(pygfx.Image):
126+
"""
127+
Similar to pygfx.Image, only difference is that it contains a few properties to keep track of
128+
row chunk index, column chunk index
122129
123130
131+
"""
132+
def _wgpu_get_pick_info(self, pick_value):
133+
tex = self.geometry.grid
134+
if hasattr(tex, "texture"):
135+
tex = tex.texture # tex was a view
136+
# This should match with the shader
137+
values = unpack_bitfield(pick_value, wobject_id=20, x=22, y=22)
138+
x = values["x"] / 4194304 * tex.size[0] - 0.5
139+
y = values["y"] / 4194304 * tex.size[1] - 0.5
140+
ix, iy = int(x + 0.5), int(y + 0.5)
141+
return {
142+
"index": (ix, iy),
143+
"pixel_coord": (x - ix, y - iy),
144+
"row_chunk_index": self.row_chunk_index,
145+
"col_chunk_index": self.col_chunk_index
146+
}
147+
148+
@property
149+
def row_chunk_index(self) -> int:
150+
return self._row_chunk_index
151+
152+
@row_chunk_index.setter
153+
def row_chunk_index(self, index: int):
154+
self._row_chunk_index = index
155+
156+
@property
157+
def col_chunk_index(self) -> int:
158+
return self._col_chunk_index
159+
160+
@col_chunk_index.setter
161+
def col_chunk_index(self, index: int):
162+
self._col_chunk_index = index
163+
164+
165+
class HeatmapGraphic(Graphic, Interaction):
166+
feature_events = (
167+
"data",
168+
"cmap",
169+
)
170+
171+
def __init__(
172+
self,
173+
data: Any,
174+
vmin: int = None,
175+
vmax: int = None,
176+
cmap: str = 'plasma',
177+
filter: str = "nearest",
178+
chunk_size: int = 8192,
179+
*args,
180+
**kwargs
181+
):
182+
"""
183+
Create an Image Graphic
184+
185+
Parameters
186+
----------
187+
data: array-like
188+
array-like, usually numpy.ndarray, must support ``memoryview()``
189+
Tensorflow Tensors also work **probably**, but not thoroughly tested
190+
| shape must be ``[x_dim, y_dim]``
191+
vmin: int, optional
192+
minimum value for color scaling, calculated from data if not provided
193+
vmax: int, optional
194+
maximum value for color scaling, calculated from data if not provided
195+
cmap: str, optional, default "plasma"
196+
colormap to use to display the data
197+
filter: str, optional, default "nearest"
198+
interpolation filter, one of "nearest" or "linear"
199+
chunk_size: int, default 8192, max 8192
200+
chunk size for each tile used to make up the heatmap texture
201+
args:
202+
additional arguments passed to Graphic
203+
kwargs:
204+
additional keyword arguments passed to Graphic
205+
206+
Examples
207+
--------
208+
.. code-block:: python
209+
210+
from fastplotlib import Plot
211+
# create a `Plot` instance
212+
plot = Plot()
213+
# make some random 2D image data
214+
data = np.random.rand(512, 512)
215+
# plot the image data
216+
plot.add_image(data=data)
217+
# show the plot
218+
plot.show()
219+
"""
220+
221+
super().__init__(*args, **kwargs)
222+
223+
if chunk_size > 8192:
224+
raise ValueError("Maximum chunk size is 8192")
225+
226+
self.data = HeatmapDataFeature(self, data)
227+
228+
row_chunks = range(ceil(data.shape[0] / chunk_size))
229+
col_chunks = range(ceil(data.shape[1] / chunk_size))
230+
231+
chunks = list(product(row_chunks, col_chunks))
232+
# chunks is the index position of each chunk
233+
234+
start_ixs = [list(map(lambda c: c * chunk_size, chunk)) for chunk in chunks]
235+
stop_ixs = [list(map(lambda c: c + chunk_size, chunk)) for chunk in start_ixs]
236+
237+
self._world_object = pygfx.Group()
238+
239+
if (vmin is None) or (vmax is None):
240+
vmin, vmax = quick_min_max(data)
241+
242+
self.cmap = HeatmapCmapFeature(self, cmap)
243+
self._material = pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=self.cmap())
244+
245+
for start, stop, chunk in zip(start_ixs, stop_ixs, chunks):
246+
row_start, col_start = start
247+
row_stop, col_stop = stop
248+
249+
# x and y positions of the Tile in world space coordinates
250+
y_pos, x_pos = row_start, col_start
251+
252+
tex_view = pygfx.Texture(data[row_start:row_stop, col_start:col_stop], dim=2).get_view(filter=filter)
253+
geometry = pygfx.Geometry(grid=tex_view)
254+
# material = pygfx.ImageBasicMaterial(clim=(0, 1), map=self.cmap())
255+
256+
img = _ImageTile(geometry, self._material)
257+
258+
# row and column chunk index for this Tile
259+
img.row_chunk_index = chunk[0]
260+
img.col_chunk_index = chunk[1]
261+
262+
img.position.set_x(x_pos)
263+
img.position.set_y(y_pos)
264+
265+
self.world_object.add(img)
266+
267+
@property
268+
def vmin(self) -> float:
269+
"""Minimum contrast limit."""
270+
return self._material.clim[0]
271+
272+
@vmin.setter
273+
def vmin(self, value: float):
274+
"""Minimum contrast limit."""
275+
self._material.clim = (
276+
value,
277+
self._material.clim[1]
278+
)
279+
280+
@property
281+
def vmax(self) -> float:
282+
"""Maximum contrast limit."""
283+
return self._material.clim[1]
284+
285+
@vmax.setter
286+
def vmax(self, value: float):
287+
"""Maximum contrast limit."""
288+
self._material.clim = (
289+
self._material.clim[0],
290+
value
291+
)
292+
293+
def _set_feature(self, feature: str, new_data: Any, indices: Any):
294+
pass
295+
296+
def _reset_feature(self, feature: str):
297+
pass

‎fastplotlib/layouts/_subplot.py

Copy file name to clipboardExpand all lines: fastplotlib/layouts/_subplot.py
-3Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,6 @@ def add_graphic(self, graphic, center: bool = True):
255255
graphic.world_object.position.z = len(self._graphics)
256256
super(Subplot, self).add_graphic(graphic, center)
257257

258-
if isinstance(graphic, graphics.HeatmapGraphic):
259-
self.controller.scale.y = copysign(self.controller.scale.y, -1)
260-
261258
def set_axes_visibility(self, visible: bool):
262259
"""Toggles axes visibility."""
263260
if visible:

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.