@@ -170,6 +170,7 @@ def __init__(
170
170
vmin_vmax_sliders : bool = False ,
171
171
grid_shape : Tuple [int , int ] = None ,
172
172
names : List [str ] = None ,
173
+ grid_plot_kwargs : dict = None ,
173
174
** kwargs
174
175
):
175
176
"""
@@ -471,12 +472,12 @@ def __init__(
471
472
472
473
if vmin_vmax_sliders :
473
474
data_range = np .ptp (minmax )
474
- data_range_30p = np .ptp (minmax ) * 0.3
475
+ data_range_40p = np .ptp (minmax ) * 0.3
475
476
476
477
minmax_slider = FloatRangeSlider (
477
478
value = minmax ,
478
- min = minmax [0 ] - data_range_30p ,
479
- max = minmax [1 ] + data_range_30p ,
479
+ min = minmax [0 ] - data_range_40p ,
480
+ max = minmax [1 ] + data_range_40p ,
480
481
step = data_range / 150 ,
481
482
description = f"min-max" ,
482
483
readout = True ,
@@ -494,11 +495,15 @@ def __init__(
494
495
kwargs ["vmin" ], kwargs ["vmax" ] = minmax
495
496
496
497
frame = self ._process_indices (self .data [0 ], slice_indices = self ._current_index )
498
+ frame = self ._process_frame_apply (frame , 0 )
497
499
498
500
self .image_graphics : List [ImageGraphic ] = [self .plot .add_image (data = frame , name = "image" , ** kwargs )]
499
501
500
502
elif self ._plot_type == "grid" :
501
- self ._plot : GridPlot = GridPlot (shape = grid_shape , controllers = "sync" )
503
+ if grid_plot_kwargs is None :
504
+ grid_plot_kwargs = {"controllers" : "sync" }
505
+
506
+ self ._plot : GridPlot = GridPlot (shape = grid_shape , ** grid_plot_kwargs )
502
507
503
508
self .image_graphics = list ()
504
509
for data_ix , (d , subplot ) in enumerate (zip (self .data , self .plot )):
@@ -513,12 +518,12 @@ def __init__(
513
518
514
519
if vmin_vmax_sliders :
515
520
data_range = np .ptp (minmax )
516
- data_range_30p = np .ptp (minmax ) * 0.4
521
+ data_range_40p = np .ptp (minmax ) * 0.4
517
522
518
523
minmax_slider = FloatRangeSlider (
519
524
value = minmax ,
520
- min = minmax [0 ] - data_range_30p ,
521
- max = minmax [1 ] + data_range_30p ,
525
+ min = minmax [0 ] - data_range_40p ,
526
+ max = minmax [1 ] + data_range_40p ,
522
527
step = data_range / 150 ,
523
528
description = f"mm: { name_slider } " ,
524
529
readout = True ,
@@ -539,6 +544,7 @@ def __init__(
539
544
_kwargs = kwargs
540
545
541
546
frame = self ._process_indices (d , slice_indices = self ._current_index )
547
+ frame = self ._process_frame_apply (frame , data_ix )
542
548
ig = ImageGraphic (frame , name = "image" , ** _kwargs )
543
549
subplot .add_graphic (ig )
544
550
subplot .name = name
@@ -767,11 +773,17 @@ def _get_window_indices(self, data_ix, dim, indices_dim):
767
773
return indices_dim
768
774
769
775
def _process_frame_apply (self , array , data_ix ) -> np .ndarray :
776
+ if callable (self .frame_apply ):
777
+ return self .frame_apply (array )
778
+
770
779
if data_ix not in self .frame_apply .keys ():
771
780
return array
772
- if self .frame_apply [data_ix ] is not None :
781
+
782
+ elif self .frame_apply [data_ix ] is not None :
773
783
return self .frame_apply [data_ix ](array )
774
784
785
+ return array
786
+
775
787
def _slider_value_changed (
776
788
self ,
777
789
dimension : str ,
@@ -801,6 +813,32 @@ def _set_slider_layout(self, *args):
801
813
for mm in self .vmin_vmax_sliders :
802
814
mm .layout = Layout (width = f"{ w } px" )
803
815
816
+ def _get_vmin_vmax_range (self , data : np .ndarray ) -> Tuple [int , int ]:
817
+ minmax = quick_min_max (data )
818
+
819
+ data_range = np .ptp (minmax )
820
+ data_range_40p = np .ptp (minmax ) * 0.4
821
+
822
+ _range = (
823
+ minmax ,
824
+ data_range ,
825
+ minmax [0 ] - data_range_40p ,
826
+ minmax [1 ] + data_range_40p
827
+ )
828
+
829
+ return _range
830
+
831
+ def reset_vmin_vmax (self ):
832
+ """
833
+ Reset the vmin and vmax w.r.t. the currently displayed image(s)
834
+ """
835
+ for i , ig in enumerate (self .image_graphics ):
836
+ mm = self ._get_vmin_vmax_range (ig .data ())
837
+ self .vmin_vmax_sliders [i ].min = mm [2 ]
838
+ self .vmin_vmax_sliders [i ].max = mm [3 ]
839
+ self .vmin_vmax_sliders [i ].step = mm [1 ] / 150
840
+ self .vmin_vmax_sliders [i ].value = mm [0 ]
841
+
804
842
def show (self ):
805
843
"""
806
844
Show the widget
0 commit comments