1
1
from ..plot import Plot
2
2
from ..layouts import GridPlot
3
3
from ..graphics import Image
4
- from ipywidgets .widgets import IntSlider , VBox , HBox
4
+ from ipywidgets .widgets import IntSlider , VBox , HBox , Layout
5
5
import numpy as np
6
6
from typing import *
7
7
from warnings import warn
@@ -24,6 +24,12 @@ def calc_gridshape(n):
24
24
)
25
25
26
26
27
+ def get_indexer (ndim : int , dim_index : int , slice_index : int ) -> slice :
28
+ dim_index = [slice (None )] * ndim
29
+ dim_index [dim_index ] = slice_index
30
+ return tuple (dim_index )
31
+
32
+
27
33
class ImageWidget :
28
34
def __init__ (
29
35
self ,
@@ -33,6 +39,7 @@ def __init__(
33
39
slider_axes : Union [int , str , dict ] = None ,
34
40
frame_apply : Union [callable , dict ] = None ,
35
41
grid_shape : Tuple [int , int ] = None ,
42
+ ** kwargs
36
43
):
37
44
# single image
38
45
if isinstance (data , np .ndarray ):
@@ -71,21 +78,85 @@ def __init__(
71
78
if axes_order is None :
72
79
self .axes_order : List [str ] = [DEFAULT_AXES_ORDER [ndim ] for i in range (len (data ))]
73
80
74
- if isinstance (slider_axes , (int )):
75
- self ._slider_axes : List [int ] = [slider_axes for i in range (len (data ))]
81
+ # if a single one is provided
82
+ if isinstance (slider_axes , (int , str )):
83
+ if isinstance (slider_axes , (int )):
84
+ self ._slider_axes = slider_axes
76
85
77
- elif isinstance (slider_axes , str ):
78
- self ._slider_axes : List [int ] = [self .axes_order .index (slider_axes )]
86
+ # also if a single one is provided, get the integer dimension index from the axes_oder string
87
+ elif isinstance (slider_axes , str ):
88
+ self ._slider_axes = self .axes_order .index (slider_axes )
79
89
80
- self .sliders : List [IntSlider ] = [
81
- IntSlider (
90
+ self .slider : IntSlider = IntSlider (
82
91
min = 0 ,
83
- max = data .shape [slider_axes ] - 1 ,
92
+ max = data .shape [self . _slider_axes ] - 1 ,
84
93
value = 0 ,
85
94
step = 1 ,
86
- description = f"slider axis: { slider_axes } "
95
+ description = f"slider axis: { self ._slider_axes } "
96
+ )
97
+
98
+ # individual slider for each data array
99
+ elif isinstance (slider_axes , dict ):
100
+ if not len (slider_axes .keys ()) == len (self .data ):
101
+ raise ValueError (
102
+ f"Must provide slider_axes entry for every input `data` array"
103
+ )
104
+
105
+ if not isinstance (axes_order , dict ):
106
+ raise ValueError ("Must pass `axes_order` dict if passing a dict of `slider_axes`" )
107
+
108
+ if not len (axes_order .keys ()) == len (self .data ):
109
+ raise ValueError (
110
+ f"Must provide `axes_order` entry for every input `data` array"
111
+ )
112
+
113
+ # convert str type desired slider axes to dimension index integers
114
+ # matchup to the given axes_order dict
115
+ _axes = [
116
+ self .axes_order [array ].index (slider_axes [array ])
117
+ if isinstance (dim_index , str )
118
+ else dim_index
119
+ for
120
+ array , dim_index in slider_axes .items ()
121
+ ]
122
+
123
+ self .sliders : Dict [IntSlider ] = {
124
+ array : IntSlider (
125
+ min = 0 ,
126
+ max = array .shape [dim ] - 1 ,
127
+ step = 1 ,
128
+ value = 0 ,
129
+ )
130
+ for array , dim in self .axes_order .items ()
131
+ }
132
+
133
+ if self .plot_type == Plot :
134
+ self .plot = Plot ()
135
+
136
+ slice_index = get_indexer (ndim , self ._slider_axes , slice_index = 0 )
137
+
138
+ self .image_graphics : List [Image ] = [self .plot .image (
139
+ data = data [0 ][slice_index ],
140
+ ** kwargs
141
+ )]
142
+
143
+ self .slider .observe (
144
+ lambda x : self .image_graphics [0 ].update_data (
145
+ data [0 ][
146
+ get_indexer (ndim , self ._slider_axes , slice_index = x ["new" ])
147
+ ]
148
+ ),
149
+ names = "value"
87
150
)
88
- ]
151
+
152
+ self .widget = VBox ([self .plot , self .slider ])
153
+
154
+ elif self .plot_type == GridPlot :
155
+ pass
156
+
157
+ def set_frame_slider_width (self ):
158
+ w , h = self .plot .renderer .logical_size
159
+ self .slider .layout = Layout (width = f"{ w } px" )
89
160
90
161
def slider_changed (self ):
91
162
pass
0 commit comments