Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

HeatmapGraphic, supports dims larger than 8192 #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions 4 fastplotlib/graphics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .histogram import HistogramGraphic
from .line import LineGraphic
from .scatter import ScatterGraphic
from .image import ImageGraphic
from .heatmap import HeatmapGraphic
from .image import ImageGraphic, HeatmapGraphic
# from .heatmap import HeatmapGraphic
from .text import TextGraphic
from .line_collection import LineCollection, LineStack

Expand Down
4 changes: 2 additions & 2 deletions 4 fastplotlib/graphics/features/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature
from ._data import PointsDataFeature, ImageDataFeature
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature, HeatmapCmapFeature
from ._data import PointsDataFeature, ImageDataFeature, HeatmapDataFeature
from ._present import PresentFeature
from ._thickness import ThicknessFeature
from ._base import GraphicFeature, GraphicFeatureIndexable
30 changes: 29 additions & 1 deletion 30 fastplotlib/graphics/features/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,33 @@
from pygfx import Buffer


supported_dtypes = [
np.uint8,
np.uint16,
np.uint32,
np.int8,
np.int16,
np.int32,
np.float16,
np.float32
]


def to_gpu_supported_dtype(array):
if isinstance(array, np.ndarray):
if array.dtype not in supported_dtypes:
if np.issubdtype(array.dtype, np.integer):
warn(f"converting {array.dtype} array to int32")
return array.astype(np.int32)
elif np.issubdtype(array.dtype, np.floating):
warn(f"converting {array.dtype} array to float32")
return array.astype(np.float32, copy=False)
else:
raise TypeError("Unsupported type, supported array types must be int or float dtypes")

return array


class FeatureEvent:
"""
type: <feature_name>, example: "colors"
Expand Down Expand Up @@ -43,7 +70,7 @@ def __init__(self, parent, data: Any, collection_index: int = None):
"""
self._parent = parent
if isinstance(data, np.ndarray):
data = data.astype(np.float32)
data = to_gpu_supported_dtype(data)

self._data = data

Expand Down Expand Up @@ -227,3 +254,4 @@ def _update_range_indices(self, key):
self._buffer.update_range(ix, size=1)
else:
raise TypeError("must pass int or slice to update range")

13 changes: 13 additions & 0 deletions 13 fastplotlib/graphics/features/_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,16 @@ def _feature_changed(self, key, new_data):
event_data = FeatureEvent(type="cmap", pick_info=pick_info)

self._call_event_handlers(event_data)


class HeatmapCmapFeature(ImageCmapFeature):
"""
Colormap for HeatmapGraphic
"""

def _set(self, cmap_name: str):
self._parent._material.map.texture.data[:] = make_colors(256, cmap_name)
self._parent._material.map.texture.update_range((0, 0, 0), size=(256, 1, 1))
self.name = cmap_name

self._feature_changed(key=None, new_data=self.name)
57 changes: 47 additions & 10 deletions 57 fastplotlib/graphics/features/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@
import numpy as np
from pygfx import Buffer, Texture

from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent


def to_float32(array):
if isinstance(array, np.ndarray):
return array.astype(np.float32, copy=False)

return array
from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent, to_gpu_supported_dtype


class PointsDataFeature(GraphicFeatureIndexable):
Expand Down Expand Up @@ -102,7 +95,7 @@ def __init__(self, parent, data: Any):
"``[x_dim, y_dim]`` or ``[x_dim, y_dim, rgb]``"
)

data = to_float32(data)
data = to_gpu_supported_dtype(data)
super(ImageDataFeature, self).__init__(parent, data)

@property
Expand All @@ -114,7 +107,7 @@ def __getitem__(self, item):

def __setitem__(self, key, value):
# make sure float32
value = to_float32(value)
value = to_gpu_supported_dtype(value)

self._buffer.data[key] = value
self._update_range(key)
Expand Down Expand Up @@ -145,3 +138,47 @@ def _feature_changed(self, key, new_data):
event_data = FeatureEvent(type="data", pick_info=pick_info)

self._call_event_handlers(event_data)


class HeatmapDataFeature(ImageDataFeature):
@property
def _buffer(self) -> List[Texture]:
return [img.geometry.grid.texture for img in self._parent.world_object.children]

def __getitem__(self, item):
return self._data[item]

def __setitem__(self, key, value):
# make sure supported type, not float64 etc.
value = to_gpu_supported_dtype(value)

self._data[key] = value
self._update_range(key)

# avoid creating dicts constantly if there are no events to handle
if len(self._event_handlers) > 0:
self._feature_changed(key, value)

def _update_range(self, key):
for buffer in self._buffer:
buffer.update_range((0, 0, 0), size=buffer.size)

def _feature_changed(self, key, new_data):
if key is not None:
key = cleanup_slice(key, self._upper_bound)
if isinstance(key, int):
indices = [key]
elif isinstance(key, slice):
indices = range(key.start, key.stop, key.step)
elif key is None:
indices = None

pick_info = {
"index": indices,
"world_object": self._parent.world_object,
"new_data": new_data
}

event_data = FeatureEvent(type="data", pick_info=pick_info)

self._call_event_handlers(event_data)
176 changes: 175 additions & 1 deletion 176 fastplotlib/graphics/image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import *
from math import ceil
from itertools import product

import pygfx
from pygfx.utils import unpack_bitfield

from ._base import Graphic, Interaction, PreviouslyModifiedData
from .features import ImageCmapFeature, ImageDataFeature
from .features import ImageCmapFeature, ImageDataFeature, HeatmapDataFeature, HeatmapCmapFeature
from ..utils import quick_min_max


Expand Down Expand Up @@ -119,5 +122,176 @@ def _reset_feature(self, feature: str):
pass


class _ImageTile(pygfx.Image):
"""
Similar to pygfx.Image, only difference is that it contains a few properties to keep track of
row chunk index, column chunk index


"""
def _wgpu_get_pick_info(self, pick_value):
tex = self.geometry.grid
if hasattr(tex, "texture"):
tex = tex.texture # tex was a view
# This should match with the shader
values = unpack_bitfield(pick_value, wobject_id=20, x=22, y=22)
x = values["x"] / 4194304 * tex.size[0] - 0.5
y = values["y"] / 4194304 * tex.size[1] - 0.5
ix, iy = int(x + 0.5), int(y + 0.5)
return {
"index": (ix, iy),
"pixel_coord": (x - ix, y - iy),
"row_chunk_index": self.row_chunk_index,
"col_chunk_index": self.col_chunk_index
}

@property
def row_chunk_index(self) -> int:
return self._row_chunk_index

@row_chunk_index.setter
def row_chunk_index(self, index: int):
self._row_chunk_index = index

@property
def col_chunk_index(self) -> int:
return self._col_chunk_index

@col_chunk_index.setter
def col_chunk_index(self, index: int):
self._col_chunk_index = index


class HeatmapGraphic(Graphic, Interaction):
feature_events = (
"data",
"cmap",
)

def __init__(
self,
data: Any,
vmin: int = None,
vmax: int = None,
cmap: str = 'plasma',
filter: str = "nearest",
chunk_size: int = 8192,
*args,
**kwargs
):
"""
Create an Image Graphic

Parameters
----------
data: array-like
array-like, usually numpy.ndarray, must support ``memoryview()``
Tensorflow Tensors also work **probably**, but not thoroughly tested
| shape must be ``[x_dim, y_dim]``
vmin: int, optional
minimum value for color scaling, calculated from data if not provided
vmax: int, optional
maximum value for color scaling, calculated from data if not provided
cmap: str, optional, default "plasma"
colormap to use to display the data
filter: str, optional, default "nearest"
interpolation filter, one of "nearest" or "linear"
chunk_size: int, default 8192, max 8192
chunk size for each tile used to make up the heatmap texture
args:
additional arguments passed to Graphic
kwargs:
additional keyword arguments passed to Graphic

Examples
--------
.. code-block:: python

from fastplotlib import Plot
# create a `Plot` instance
plot = Plot()
# make some random 2D image data
data = np.random.rand(512, 512)
# plot the image data
plot.add_image(data=data)
# show the plot
plot.show()
"""

super().__init__(*args, **kwargs)

if chunk_size > 8192:
raise ValueError("Maximum chunk size is 8192")

self.data = HeatmapDataFeature(self, data)

row_chunks = range(ceil(data.shape[0] / chunk_size))
col_chunks = range(ceil(data.shape[1] / chunk_size))

chunks = list(product(row_chunks, col_chunks))
# chunks is the index position of each chunk

start_ixs = [list(map(lambda c: c * chunk_size, chunk)) for chunk in chunks]
stop_ixs = [list(map(lambda c: c + chunk_size, chunk)) for chunk in start_ixs]

self._world_object = pygfx.Group()

if (vmin is None) or (vmax is None):
vmin, vmax = quick_min_max(data)

self.cmap = HeatmapCmapFeature(self, cmap)
self._material = pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=self.cmap())

for start, stop, chunk in zip(start_ixs, stop_ixs, chunks):
row_start, col_start = start
row_stop, col_stop = stop

# x and y positions of the Tile in world space coordinates
y_pos, x_pos = row_start, col_start

tex_view = pygfx.Texture(data[row_start:row_stop, col_start:col_stop], dim=2).get_view(filter=filter)
geometry = pygfx.Geometry(grid=tex_view)
# material = pygfx.ImageBasicMaterial(clim=(0, 1), map=self.cmap())

img = _ImageTile(geometry, self._material)

# row and column chunk index for this Tile
img.row_chunk_index = chunk[0]
img.col_chunk_index = chunk[1]

img.position.set_x(x_pos)
img.position.set_y(y_pos)

self.world_object.add(img)

@property
def vmin(self) -> float:
"""Minimum contrast limit."""
return self._material.clim[0]

@vmin.setter
def vmin(self, value: float):
"""Minimum contrast limit."""
self._material.clim = (
value,
self._material.clim[1]
)

@property
def vmax(self) -> float:
"""Maximum contrast limit."""
return self._material.clim[1]

@vmax.setter
def vmax(self, value: float):
"""Maximum contrast limit."""
self._material.clim = (
self._material.clim[0],
value
)

def _set_feature(self, feature: str, new_data: Any, indices: Any):
pass

def _reset_feature(self, feature: str):
pass
3 changes: 0 additions & 3 deletions 3 fastplotlib/layouts/_subplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,6 @@ def add_graphic(self, graphic, center: bool = True):
graphic.world_object.position.z = len(self._graphics)
super(Subplot, self).add_graphic(graphic, center)

if isinstance(graphic, graphics.HeatmapGraphic):
self.controller.scale.y = copysign(self.controller.scale.y, -1)

def set_axes_visibility(self, visible: bool):
"""Toggles axes visibility."""
if visible:
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.