Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
ColorFeature working, updated graphics and base Graphic color handlin…
…g, updated examples
  • Loading branch information
kushalkolar committed Dec 19, 2022
commit 5ceb39764c3aece252a1a1ae2cd7dfcfa88b8b18
16 changes: 8 additions & 8 deletions 16 examples/scatter.ipynb

Large diffs are not rendered by default.

81 changes: 50 additions & 31 deletions 81 examples/simple.ipynb

Large diffs are not rendered by default.

57 changes: 24 additions & 33 deletions 57 fastplotlib/graphics/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,48 @@
import pygfx

from fastplotlib.utils import get_colors, map_labels_to_colors
from ._graphic_attribute import ColorFeature


class Graphic:
def __init__(
self,
data,
colors: np.ndarray = None,
colors_length: int = None,
colors: Any = False,
n_colors: int = None,
cmap: str = None,
alpha: float = 1.0,
name: str = None
):
"""

Parameters
----------
data
colors: Any
if ``False``, no color generation is performed, cmap is also ignored.
n_colors
cmap
alpha
name
"""
self.data = data.astype(np.float32)
self.colors = None

self.name = name

# if colors_length is None:
# colors_length = self.data.shape[0]
if n_colors is None:
n_colors = self.data.shape[0]

if colors is not False:
self._set_colors(colors, colors_length, cmap, alpha, )

def _set_colors(self, colors, colors_length, cmap, alpha):
if colors_length is None:
colors_length = self.data.shape[0]

if colors is None and cmap is None: # just white
self.colors = np.vstack([[1., 1., 1., 1.]] * colors_length).astype(np.float32)

elif (colors is None) and (cmap is not None):
self.colors = get_colors(n_colors=colors_length, cmap=cmap, alpha=alpha)

elif (colors is not None) and (cmap is None):
# assume it's already an RGBA array
colors = np.array(colors)
if colors.shape == (1, 4) or colors.shape == (4,):
self.colors = np.vstack([colors] * colors_length).astype(np.float32)
elif colors.ndim == 2 and colors.shape[1] == 4 and colors.shape[0] == colors_length:
self.colors = colors.astype(np.float32)
else:
raise ValueError(f"Colors array must have ndim == 2 and shape of [<n_datapoints>, 4]")
if cmap is not None and colors is not False:
colors = get_colors(n_colors=n_colors, cmap=cmap, alpha=alpha)

elif (colors is not None) and (cmap is not None):
if colors.ndim == 1 and np.issubdtype(colors.dtype, np.integer):
# assume it's a mapping of colors
self.colors = np.array(map_labels_to_colors(colors, cmap, alpha=alpha)).astype(np.float32)
if colors is not False:
self.colors = ColorFeature(parent=self, colors=colors, n_colors=n_colors, alpha=alpha)

else:
raise ValueError("Unknown color format")
@property
def world_object(self) -> pygfx.WorldObject:
return self._world_object

@property
def children(self) -> pygfx.WorldObject:
Expand All @@ -67,4 +59,3 @@ def __repr__(self):
return f"'{self.name}' fastplotlib.{self.__class__.__name__} @ {hex(id(self))}"
else:
return f"fastplotlib.{self.__class__.__name__} @ {hex(id(self))}"

8 changes: 6 additions & 2 deletions 8 fastplotlib/graphics/_graphic_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def __init__(self, parent, colors, n_colors, alpha: float = 1.0):

n_colors: number of colors to hold, if passing in a single str or single RGBA array
"""
# if provided as a numpy array of str
if isinstance(colors, np.ndarray):
if colors.dtype.kind in ["U", "S"]:
colors = colors.tolist()
# if the color is provided as a numpy array
if isinstance(colors, np.ndarray):
if colors.shape == (4,): # single RGBA array
Expand All @@ -97,10 +101,10 @@ def __init__(self, parent, colors, n_colors, alpha: float = 1.0):
)

# if the color is provided as an iterable
elif isinstance(colors, (list, tuple)):
elif isinstance(colors, (list, tuple, np.ndarray)):
# if iterable of str
if all([isinstance(val, str) for val in colors]):
if not len(list) == n_colors:
if not len(colors) == n_colors:
raise ValueError(
f"Valid iterable color arguments must be a `tuple` or `list` of `str` "
f"where the length of the iterable is the same as the number of datapoints."
Expand Down
6 changes: 3 additions & 3 deletions 6 fastplotlib/graphics/histogram.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _warnings import warn
from warnings import warn
from typing import Union, Dict

import numpy as np
Expand Down Expand Up @@ -82,9 +82,9 @@ def __init__(

data = np.vstack([x_positions_bins, self.hist])

super(HistogramGraphic, self).__init__(data=data, colors=colors, colors_length=n_bins, **kwargs)
super(HistogramGraphic, self).__init__(data=data, colors=colors, n_colors=n_bins, **kwargs)

self.world_object: pygfx.Group = pygfx.Group()
self._world_object: pygfx.Group = pygfx.Group()

for x_val, y_val, bin_center in zip(x_positions_bins, self.hist, self.bin_centers):
geometry = pygfx.plane_geometry(
Expand Down
2 changes: 1 addition & 1 deletion 2 fastplotlib/graphics/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
if (vmin is None) or (vmax is None):
vmin, vmax = quick_min_max(data)

self.world_object: pygfx.Image = pygfx.Image(
self._world_object: pygfx.Image = pygfx.Image(
pygfx.Geometry(grid=pygfx.Texture(self.data, dim=2)),
pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=get_cmap_texture(cmap))
)
Expand Down
20 changes: 8 additions & 12 deletions 20 fastplotlib/graphics/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(
data: Any,
z_position: float = 0.0,
size: float = 2.0,
colors: np.ndarray = None,
colors: Union[str, np.ndarray, Iterable] = "w",
cmap: str = None,
*args,
**kwargs
Expand All @@ -30,7 +30,9 @@ def __init__(
size: float, optional
thickness of the line

colors:
colors: str, array, or iterable
specify colors as a single human readable string, a single RGBA array,
or an iterable of strings or RGBA arrays

cmap: str, optional
apply a colormap to the line instead of assigning colors manually
Expand All @@ -51,8 +53,8 @@ def __init__(

self.data = np.ascontiguousarray(self.data)

self.world_object: pygfx.Line = pygfx.Line(
geometry=pygfx.Geometry(positions=self.data, colors=self.colors),
self._world_object: pygfx.Line = pygfx.Line(
geometry=pygfx.Geometry(positions=self.data, colors=self.colors.data),
material=material(thickness=size, vertex_colors=True)
)

Expand All @@ -61,7 +63,7 @@ def __init__(
def fix_data(self):
# TODO: data should probably be a property of any Graphic?? Or use set_data() and get_data()
if self.data.ndim == 1:
self.data = np.dstack([np.arange(self.data.size), self.data])[0]
self.data = np.dstack([np.arange(self.data.size), self.data])[0].astype(np.float32)

if self.data.shape[1] != 3:
if self.data.shape[1] != 2:
Expand All @@ -70,17 +72,11 @@ def fix_data(self):
# zeros for z
zs = np.zeros(self.data.shape[0], dtype=np.float32)

self.data = np.dstack([self.data[:, 0], self.data[:, 1], zs])[0]
self.data = np.dstack([self.data[:, 0], self.data[:, 1], zs])[0].astype(np.float32)

def update_data(self, data: np.ndarray):
self.data = data.astype(np.float32)
self.fix_data()

self.world_object.geometry.positions.data[:] = self.data
self.world_object.geometry.positions.update_range()

def update_colors(self, colors: np.ndarray):
super(LineGraphic, self)._set_colors(colors=colors, colors_length=self.data.shape[0], cmap=None, alpha=None)

self.world_object.geometry.colors.data[:] = self.colors
self.world_object.geometry.colors.update_range()
56 changes: 22 additions & 34 deletions 56 fastplotlib/graphics/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,37 @@


class ScatterGraphic(Graphic):
def __init__(self, data: np.ndarray, z_position: float = 0.0, size: int = 1, colors: np.ndarray = None, cmap: str = None, *args, **kwargs):
def __init__(self, data: np.ndarray, z_position: float = 0.0, size: int = 1, colors: np.ndarray = "w", cmap: str = None, *args, **kwargs):
super(ScatterGraphic, self).__init__(data, colors=colors, cmap=cmap, *args, **kwargs)

if self.data.ndim == 1:
# assume single 3D point
if not self.data.size == 3:
raise ValueError("If passing single you must specify all coordinates, i.e. x, y and z.")
elif self.data.shape[1] != 3:
if self.data.shape[1] == 2:

# zeros
zs = np.zeros(self.data.shape[0], dtype=np.float32)

self.data = np.dstack([self.data[:, 0], self.data[:, 1], zs])[0]
if self.data.shape[1] > 3 or self.data.shape[1] < 1:
raise ValueError("Must pass 2D or 3D data or a single point")
self.fix_data()

self.world_object: pygfx.Group = pygfx.Group()
self.points_objects: List[pygfx.Points] = list()
sizes = np.full(self.data.shape[0], size, dtype=np.float32)

for color in np.unique(self.colors, axis=0):
positions = self._process_positions(
self.data[np.all(self.colors == color, axis=1)]
)
self._world_object: pygfx.Points = pygfx.Points(
pygfx.Geometry(positions=self.data, sizes=sizes, colors=self.colors.data),
material=pygfx.PointsMaterial(vertex_colors=True, vertex_sizes=True)
)

points = pygfx.Points(
pygfx.Geometry(positions=positions),
pygfx.PointsMaterial(size=size, color=color)
)
self.world_object.position.z = z_position

self.world_object.add(points)
self.points_objects.append(points)
def fix_data(self):
# TODO: data should probably be a property of any Graphic?? Or use set_data() and get_data()
if self.data.ndim == 1:
self.data = np.array([self.data])

self.world_object.position.z = z_position
if self.data.shape[1] != 3:
if self.data.shape[1] != 2:
raise ValueError("Must pass 1D, 2D or 3D data")

def _process_positions(self, positions: np.ndarray):
if positions.ndim == 1:
positions = np.array([positions])
# zeros for z
zs = np.zeros(self.data.shape[0], dtype=np.float32)

return positions
self.data = np.dstack([self.data[:, 0], self.data[:, 1], zs])[0].astype(np.float32)

def update_data(self, data: np.ndarray):
positions = self._process_positions(data).astype(np.float32)
self.data = data
self.fix_data()

self.points_objects[0].geometry.positions.data[:] = positions
self.points_objects[0].geometry.positions.update_range(positions.shape[0])
self.world_object.geometry.positions.data[:] = self.data
self.world_object.geometry.positions.update_range(self.data.shape[0])
8 changes: 1 addition & 7 deletions 8 fastplotlib/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,7 @@ def _get_cmap(name: str, alpha: float = 1.0) -> np.ndarray:
return cmap.astype(np.float32)


def get_colors(
n_colors: int,
cmap: str,
spacing: str = 'uniform',
alpha: float = 1.0
) \
-> List[Union[np.ndarray, str]]:
def get_colors(n_colors: int, cmap: str, alpha: float = 1.0) -> np.ndarray:
cmap = _get_cmap(cmap, alpha)
cm_ixs = np.linspace(0, 255, n_colors, dtype=int)
return np.take(cmap, cm_ixs, axis=0).astype(np.float32)
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.