diff --git a/docs/source/conf.py b/docs/source/conf.py index 8d17c97ae..3cf2b4e75 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -56,6 +56,7 @@ "subsection_order": ExplicitOrder( [ "../../examples/image", + "../../examples/image_volume", "../../examples/heatmap", "../../examples/image_widget", "../../examples/gridplot", diff --git a/examples/image_volume/README.rst b/examples/image_volume/README.rst new file mode 100644 index 000000000..6c349ebfa --- /dev/null +++ b/examples/image_volume/README.rst @@ -0,0 +1,2 @@ +Image Volume Examples +===================== diff --git a/examples/image_volume/image_volume_4d.py b/examples/image_volume/image_volume_4d.py new file mode 100644 index 000000000..208c3a97b --- /dev/null +++ b/examples/image_volume/image_volume_4d.py @@ -0,0 +1,83 @@ +""" +Volume movie +============ + +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'screenshot' + +import numpy as np +from scipy.ndimage import gaussian_filter +import fastplotlib as fpl +from tqdm import tqdm + + +def gen_data(p=1, noise=.5, T=256, framerate=30, firerate=2.,): + if p == 2: + gamma = np.array([1.5, -.55]) + elif p == 1: + gamma = np.array([.9]) + else: + raise + dims = (128, 128, 30) # size of image + sig = (4, 4, 2) # neurons size + bkgrd = 10 + N = 150 # number of neurons + np.random.seed(0) + centers = np.asarray([[np.random.randint(s, x - s) + for x, s in zip(dims, sig)] for i in range(N)]) + Y = np.zeros((T,) + dims, dtype=np.float32) + trueSpikes = np.random.rand(N, T) < firerate / float(framerate) + trueSpikes[:, 0] = 0 + truth = trueSpikes.astype(np.float32) + for i in tqdm(range(2, T)): + if p == 2: + truth[:, i] += gamma[0] * truth[:, i - 1] + gamma[1] * truth[:, i - 2] + else: + truth[:, i] += gamma[0] * truth[:, i - 1] + for i in tqdm(range(N)): + Y[:, centers[i, 0], centers[i, 1], centers[i, 2]] = truth[i] + tmp = np.zeros(dims) + tmp[tuple(np.array(dims)//2)] = 1. + print("gaussing filtering") + z = np.linalg.norm(gaussian_filter(tmp, sig).ravel()) + + print("finishing") + Y = bkgrd + noise * np.random.randn(*Y.shape) + 10 * gaussian_filter(Y, (0,) + sig) / z + + return Y + + +voldata = gen_data() + +fig = fpl.Figure(cameras="3d", controller_types="orbit", size=(700, 560)) + +vmin, vmax = fpl.utils.quick_min_max(voldata) + +volume = fig[0, 0].add_image_volume(voldata[0], vmin=vmin, vmax=vmax, interpolation="linear", cmap="gnuplot2") + +hlut = fpl.HistogramLUTTool(voldata, volume) + +fig[0, 0].docks["right"].size = 100 +fig[0, 0].docks["right"].controller.enabled = False +fig[0, 0].docks["right"].add_graphic(hlut) +fig[0, 0].docks["right"].auto_scale(maintain_aspect=False) + +fig.show() + + +i = 0 +def update(): + global i + + volume.data = voldata[i] + + i += 1 + if i == voldata.shape[0]: + i = 0 + + +fig.add_animations(update) + +fpl.loop.run() diff --git a/examples/image_volume/image_volume_ray.py b/examples/image_volume/image_volume_ray.py new file mode 100644 index 000000000..f16a08803 --- /dev/null +++ b/examples/image_volume/image_volume_ray.py @@ -0,0 +1,24 @@ +""" +Volume Ray mode +=============== + +View a volume, uses the fly controller by default so you can fly around the scene using WASD keys and the mouse: +https://docs.pygfx.org/stable/_autosummary/controllers/pygfx.controllers.FlyController.html#pygfx.controllers.FlyController +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'screenshot' + +import numpy as np +import fastplotlib as fpl +import imageio.v3 as iio + +voldata = iio.imread("imageio:stent.npz").astype(np.float32) + +fig = fpl.Figure(cameras="3d", size=(700, 560)) + +fig[0, 0].add_image_volume(voldata) + +fig.show() + +fpl.loop.run() diff --git a/examples/tests/testutils.py b/examples/tests/testutils.py index 4c23b3481..546ff120e 100644 --- a/examples/tests/testutils.py +++ b/examples/tests/testutils.py @@ -18,6 +18,7 @@ # examples live in themed sub-folders example_globs = [ "image/*.py", + "image_volume/*.py", "image_widget/*.py", "heatmap/*.py", "scatter/*.py", diff --git a/fastplotlib/graphics/__init__.py b/fastplotlib/graphics/__init__.py index 03f361502..57058fd9c 100644 --- a/fastplotlib/graphics/__init__.py +++ b/fastplotlib/graphics/__init__.py @@ -2,6 +2,7 @@ from .line import LineGraphic from .scatter import ScatterGraphic from .image import ImageGraphic +from .image_volume import ImageVolumeGraphic from .text import TextGraphic from .line_collection import LineCollection, LineStack @@ -10,6 +11,7 @@ "LineGraphic", "ScatterGraphic", "ImageGraphic", + "ImageVolumeGraphic", "TextGraphic", "LineCollection", "LineStack", diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py index e115107b0..924f35164 100644 --- a/fastplotlib/graphics/_base.py +++ b/fastplotlib/graphics/_base.py @@ -53,13 +53,6 @@ class Graphic: _features: dict[str, type] = dict() def __init_subclass__(cls, **kwargs): - # set the type of the graphic in lower case like "image", "line_collection", etc. - cls.type = ( - cls.__name__.lower() - .replace("graphic", "") - .replace("collection", "_collection") - .replace("stack", "_stack") - ) # set of all features cls._features = { diff --git a/fastplotlib/graphics/features/_image.py b/fastplotlib/graphics/features/_image.py index c47a26e6a..a6e3665a9 100644 --- a/fastplotlib/graphics/features/_image.py +++ b/fastplotlib/graphics/features/_image.py @@ -13,8 +13,13 @@ ) -# manages an array of 8192x8192 Textures representing chunks of an image class TextureArray(GraphicFeature): + """ + Manages an array of Textures representing chunks of an image. + + Creates multiple pygfx.Texture objects based on the GPU's max texture dimension limit. + """ + event_info_spec = [ { "dict key": "key", @@ -28,13 +33,30 @@ class TextureArray(GraphicFeature): }, ] - def __init__(self, data, isolated_buffer: bool = True): + def __init__(self, data, dim: int, isolated_buffer: bool = True): + """ + + Parameters + ---------- + dim: int, 2 | 3 + whether the data array represents a 2D or 3D texture + + """ + if dim not in (2, 3): + raise ValueError("`dim` must be 2 | 3") + + self._dim = dim + super().__init__() data = self._fix_data(data) shared = pygfx.renderers.wgpu.get_shared() - self._texture_limit_2d = shared.device.limits["max-texture-dimension-2d"] + + if self._dim == 2: + self._texture_size_limit = shared.device.limits["max-texture-dimension-2d"] + else: + self._texture_size_limit = shared.device.limits["max-texture-dimension-3d"] if isolated_buffer: # useful if data is read-only, example: memmaps @@ -47,26 +69,39 @@ def __init__(self, data, isolated_buffer: bool = True): # data start indices for each Texture self._row_indices = np.arange( 0, - ceil(self.value.shape[0] / self._texture_limit_2d) * self._texture_limit_2d, - self._texture_limit_2d, + ceil(self.value.shape[0] / self._texture_size_limit) + * self._texture_size_limit, + self._texture_size_limit, ) self._col_indices = np.arange( 0, - ceil(self.value.shape[1] / self._texture_limit_2d) * self._texture_limit_2d, - self._texture_limit_2d, + ceil(self.value.shape[1] / self._texture_size_limit) + * self._texture_size_limit, + self._texture_size_limit, ) + shape = [self.row_indices.size, self.col_indices.size] + + if self._dim == 3: + self._zdim_indices = np.arange( + 0, + ceil(self.value.shape[2] / self._texture_size_limit) + * self._texture_size_limit, + self._texture_size_limit, + ) + shape += [self.zdim_indices.size] + else: + self._zdim_indices = np.empty(0) + # buffer will be an array of textures - self._buffer: np.ndarray[pygfx.Texture] = np.empty( - shape=(self.row_indices.size, self.col_indices.size), dtype=object - ) + self._buffer: np.ndarray[pygfx.Texture] = np.empty(shape=shape, dtype=object) self._iter = None # iterate through each chunk of passed `data` # create a pygfx.Texture from this chunk for _, buffer_index, data_slice in self: - texture = pygfx.Texture(self.value[data_slice], dim=2) + texture = pygfx.Texture(self.value[data_slice], dim=self._dim) self.buffer[buffer_index] = texture @@ -99,6 +134,10 @@ def col_indices(self) -> np.ndarray: """ return self._col_indices + @property + def zdim_indices(self) -> np.ndarray: + return self._zdim_indices + @property def shared(self) -> int: return self._shared @@ -114,7 +153,17 @@ def _fix_data(self, data): return data.astype(np.float32) def __iter__(self): - self._iter = product(enumerate(self.row_indices), enumerate(self.col_indices)) + if self._dim == 2: + self._iter = product( + enumerate(self.row_indices), enumerate(self.col_indices) + ) + elif self._dim == 3: + self._iter = product( + enumerate(self.row_indices), + enumerate(self.col_indices), + enumerate(self.zdim_indices), + ) + return self def __next__(self) -> tuple[pygfx.Texture, tuple[int, int], tuple[slice, slice]]: @@ -128,22 +177,36 @@ def __next__(self) -> tuple[pygfx.Texture, tuple[int, int], tuple[slice, slice]] | tuple[int, int]: chunk index, i.e corresponding index of ``self.buffer`` array | tuple[slice, slice]: data slice of big array in this chunk and Texture """ - (chunk_row, data_row_start), (chunk_col, data_col_start) = next(self._iter) + if self._dim == 2: + (chunk_row, data_row_start), (chunk_col, data_col_start) = next(self._iter) + elif self._dim == 3: + ( + (chunk_row, data_row_start), + (chunk_col, data_col_start), + (chunk_z, data_z_start), + ) = next(self._iter) # indices for to self.buffer for this chunk - chunk_index = (chunk_row, chunk_col) + chunk_index = [chunk_row, chunk_col] + + if self._dim == 3: + chunk_index += [chunk_z] # stop indices of big data array for this chunk - row_stop = min(self.value.shape[0], data_row_start + self._texture_limit_2d) - col_stop = min(self.value.shape[1], data_col_start + self._texture_limit_2d) + row_stop = min(self.value.shape[0], data_row_start + self._texture_size_limit) + col_stop = min(self.value.shape[1], data_col_start + self._texture_size_limit) + if self._dim == 3: + z_stop = min(self.value.shape[2], data_z_start + self._texture_size_limit) # row and column slices that slice the data for this chunk from the big data array - data_slice = (slice(data_row_start, row_stop), slice(data_col_start, col_stop)) + data_slice = [slice(data_row_start, row_stop), slice(data_col_start, col_stop)] + if self._dim == 3: + data_slice += [slice(data_z_start, z_stop)] # texture for this chunk - texture = self.buffer[chunk_index] + texture = self.buffer[tuple(chunk_index)] - return texture, chunk_index, data_slice + return texture, chunk_index, tuple(data_slice) def __getitem__(self, item): return self.value[item] diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 5f198c84f..58d64768b 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -101,10 +101,10 @@ def __init__( | shape must be ``[n_rows, n_cols]``, ``[n_rows, n_cols, 3]`` for RGB or ``[n_rows, n_cols, 4]`` for RGBA vmin: int, optional - minimum value for color scaling, calculated from data if not provided + minimum value for color scaling, estimated from data if not provided vmax: int, optional - maximum value for color scaling, calculated from data if not provided + maximum value for color scaling, estimated from data if not provided cmap: str, optional, default "plasma" colormap to use to display the data @@ -129,8 +129,8 @@ def __init__( world_object = pygfx.Group() - # texture array that manages the textures on the GPU for displaying this image - self._data = TextureArray(data, isolated_buffer=isolated_buffer) + # texture array that manages the multiple textures on the GPU that represent this image + self._data = TextureArray(data, dim=2, isolated_buffer=isolated_buffer) if (vmin is None) or (vmax is None): vmin, vmax = quick_min_max(data) @@ -165,7 +165,7 @@ def __init__( ) # iterate through each texture chunk and create - # an _ImageTIle, offset the tile using the data indices + # an _ImageTile, offset the tile using the data indices for texture, chunk_index, data_slice in self._data: # create an ImageTile using the texture for this chunk diff --git a/fastplotlib/graphics/image_volume.py b/fastplotlib/graphics/image_volume.py new file mode 100644 index 000000000..0ca5697c1 --- /dev/null +++ b/fastplotlib/graphics/image_volume.py @@ -0,0 +1,231 @@ +from typing import * + +import pygfx + +from ..utils import quick_min_max +from ._base import Graphic +from .features import ( + TextureArray, + ImageCmap, + ImageVmin, + ImageVmax, + ImageInterpolation, + ImageCmapInterpolation, +) + + +class _VolumeTile(pygfx.Volume): + """ + Similar to pygfx.Volume, only difference is that it modifies the pick_info + by adding the data row start indices that correspond to this chunk of the big Volume + """ + + def __init__( + self, + geometry, + material, + data_slice: tuple[slice, slice, slice], + chunk_index: tuple[int, int, int], + **kwargs, + ): + super().__init__(geometry, material, **kwargs) + + self._data_slice = data_slice + self._chunk_index = chunk_index + + def _wgpu_get_pick_info(self, pick_value): + pick_info = super()._wgpu_get_pick_info(pick_value) + + data_row_start, data_col_start, data_z_start = ( + self.data_slice[0].start, + self.data_slice[1].start, + self.data_slice[2].start, + ) + + # add the actual data row and col start indices + x, y, z = pick_info["index"] + x += data_col_start + y += data_row_start + z += data_z_start + pick_info["index"] = (x, y, z) + + xp, yp, zp = pick_info["voxel_coord"] + xp += data_col_start + yp += data_row_start + zp += data_z_start + pick_info["voxel_coord"] = (xp, yp, zp) + + # add row chunk and col chunk index to pick_info dict + return { + **pick_info, + "data_slice": self.data_slice, + "chunk_index": self.chunk_index, + } + + @property + def data_slice(self) -> tuple[slice, slice, slice]: + return self._data_slice + + @property + def chunk_index(self) -> tuple[int, int, int]: + return self._chunk_index + + +class ImageVolumeGraphic(Graphic): + _features = { + "data": TextureArray, + "cmap": ImageCmap, + "vmin": ImageVmin, + "vmax": ImageVmax, + "interpolation": ImageInterpolation, + "cmap_interpolation": ImageCmapInterpolation, + } + + def __init__( + self, + data: Any, + mode: str = "ray", + vmin: int = None, + vmax: int = None, + cmap: str = "plasma", + interpolation: str = "nearest", + cmap_interpolation: str = "linear", + isolated_buffer: bool = True, + **kwargs, + ): + valid_modes = ["basic", "ray", "slice", "iso", "mip", "minip"] + if mode not in valid_modes: + raise ValueError( + f"invalid mode specified: {mode}, valid modes are: {valid_modes}" + ) + + super().__init__(**kwargs) + + world_object = pygfx.Group() + + # texture array that manages the textures on the GPU that represent this image volume + self._data = TextureArray(data, dim=3, isolated_buffer=isolated_buffer) + + if (vmin is None) or (vmax is None): + vmin, vmax = quick_min_max(data) + + # other graphic features + self._vmin = ImageVmin(vmin) + self._vmax = ImageVmax(vmax) + + self._interpolation = ImageInterpolation(interpolation) + + # TODO: I'm assuming RGB volume images aren't supported??? + # use TextureMap for grayscale images + self._cmap = ImageCmap(cmap) + self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) + + _map = pygfx.TextureMap( + self._cmap.texture, + filter=self._cmap_interpolation.value, + wrap="clamp-to-edge", + ) + + material_cls = getattr(pygfx, f"Volume{mode.capitalize()}Material") + + # TODO: graphic features for the various material properties + self._material = material_cls( + clim=(self._vmin.value, self._vmax.value), + map=_map, + interpolation=self._interpolation.value, + pick_write=True, + ) + + # iterate through each texture chunk and create + # a _VolumeTile, offset the tile using the data indices + for texture, chunk_index, data_slice in self._data: + # create a _VolumeTile using the texture for this chunk + vol = _VolumeTile( + geometry=pygfx.Geometry(grid=texture), + material=self._material, + data_slice=data_slice, # used to parse pick_info + chunk_index=chunk_index, + ) + + # row and column start index for this chunk + data_row_start = data_slice[0].start + data_col_start = data_slice[1].start + data_z_start = data_slice[2].start + + # offset tile position using the indices from the big data array + # that correspond to this chunk + vol.world.x = data_col_start + vol.world.y = data_row_start + vol.world.z = data_z_start + + world_object.add(vol) + + self._set_world_object(world_object) + + @property + def data(self) -> TextureArray: + """Get or set the image data""" + return self._data + + @data.setter + def data(self, data): + self._data[:] = data + + @property + def cmap(self) -> str: + """colormap name""" + return self._cmap.value + + @cmap.setter + def cmap(self, name: str): + self._cmap.set_value(self, name) + + @property + def vmin(self) -> float: + """lower contrast limit""" + return self._vmin.value + + @vmin.setter + def vmin(self, value: float): + self._vmin.set_value(self, value) + + @property + def vmax(self) -> float: + """upper contrast limit""" + return self._vmax.value + + @vmax.setter + def vmax(self, value: float): + self._vmax.set_value(self, value) + + @property + def interpolation(self) -> str: + """image data interpolation method""" + return self._interpolation.value + + @interpolation.setter + def interpolation(self, value: str): + self._interpolation.set_value(self, value) + + @property + def cmap_interpolation(self) -> str: + """cmap interpolation method""" + return self._cmap_interpolation.value + + @cmap_interpolation.setter + def cmap_interpolation(self, value: str): + self._cmap_interpolation.set_value(self, value) + + def reset_vmin_vmax(self): + """ + Reset the vmin, vmax by estimating it from the data + + Returns + ------- + None + + """ + + vmin, vmax = quick_min_max(self._data.value) + self.vmin = vmin + self.vmax = vmax diff --git a/fastplotlib/layouts/_graphic_methods_mixin.py b/fastplotlib/layouts/_graphic_methods_mixin.py index a753eec73..9c14498b1 100644 --- a/fastplotlib/layouts/_graphic_methods_mixin.py +++ b/fastplotlib/layouts/_graphic_methods_mixin.py @@ -45,10 +45,10 @@ def add_image( | shape must be ``[n_rows, n_cols]``, ``[n_rows, n_cols, 3]`` for RGB or ``[n_rows, n_cols, 4]`` for RGBA vmin: int, optional - minimum value for color scaling, calculated from data if not provided + minimum value for color scaling, estimated from data if not provided vmax: int, optional - maximum value for color scaling, calculated from data if not provided + maximum value for color scaling, estimated from data if not provided cmap: str, optional, default "plasma" colormap to use to display the data @@ -81,6 +81,34 @@ def add_image( **kwargs, ) + def add_image_volume( + self, + data: Any, + mode: str = "ray", + vmin: int = None, + vmax: int = None, + cmap: str = "plasma", + interpolation: str = "nearest", + cmap_interpolation: str = "linear", + isolated_buffer: bool = True, + **kwargs, + ) -> ImageVolumeGraphic: + """ + None + """ + return self._create_graphic( + ImageVolumeGraphic, + data, + mode, + vmin, + vmax, + cmap, + interpolation, + cmap_interpolation, + isolated_buffer, + **kwargs, + ) + def add_line_collection( self, data: Union[numpy.ndarray, List[numpy.ndarray]], diff --git a/fastplotlib/utils/functions.py b/fastplotlib/utils/functions.py index e775288d3..b276ea98b 100644 --- a/fastplotlib/utils/functions.py +++ b/fastplotlib/utils/functions.py @@ -269,20 +269,21 @@ def make_colors_dict(labels: Sequence, cmap: str, **kwargs) -> OrderedDict: def quick_min_max(data: np.ndarray, max_size=1e6) -> tuple[float, float]: """ - Adapted from pyqtgraph.ImageView. - Estimate the min/max values of *data* by subsampling. + Estimate the (min, max) values of data array by subsampling. + + Also supports array-like data types may have a `min` and `max` property that provides a pre-calculated (min, max). Parameters ---------- - data: np.ndarray or array-like with `min` and `max` attributes + data: np.ndarray or array-like max_size : int, optional - largest array size allowed in the subsampled array. Default is 1e6. + subsamples data array to this max size Returns ------- (float, float) - (min, max) + (min, max) estimate """ if hasattr(data, "min") and hasattr(data, "max"): diff --git a/scripts/generate_add_graphic_methods.py b/scripts/generate_add_graphic_methods.py index 533ae77c6..968c68d2a 100644 --- a/scripts/generate_add_graphic_methods.py +++ b/scripts/generate_add_graphic_methods.py @@ -1,5 +1,6 @@ import inspect import pathlib +import re import black @@ -19,6 +20,8 @@ for name, obj in inspect.getmembers(graphics): if inspect.isclass(obj): + if obj.__name__ == "Graphic": + continue # skip the base class modules.append(obj) @@ -49,23 +52,25 @@ def generate_add_graphics_methods(): f.write(" return graphic\n\n") for m in modules: - class_name = m - method_name = class_name.type + cls = m + cls_name = cls.__name__.replace("Graphic", "") + # from https://stackoverflow.com/a/1176023 + method_name = re.sub(r'(? {class_name.__name__}:\n" + f" def add_{method_name}{inspect.signature(cls.__init__)} -> {cls.__name__}:\n" ) f.write(' """\n') - f.write(f" {class_name.__init__.__doc__}\n") + f.write(f" {cls.__init__.__doc__}\n") f.write(' """\n') f.write( - f" return self._create_graphic({class_name.__name__}, {s} **kwargs)\n\n" + f" return self._create_graphic({cls.__name__}, {s} **kwargs)\n\n" ) f.close()