 from collections import defaultdict
 from copy import deepcopy
 from math import hypot
-from numbers import Integral as Int
-from numbers import Real
 from typing import Callable, DefaultDict, Iterable, List, Sequence, Tuple

 import numpy as np
@@ -16,6 +14,7 @@

 from adaptive.learner.learner1D import Learner1D, _get_intervals
 from adaptive.notebook_integration import ensure_holoviews
+from adaptive.types import Int, Real
 from adaptive.utils import assign_defaults, partial_function_from_dataframe

 try:
@@ -99,7 +98,7 @@ def __init__(
         if min_samples > max_samples:
             raise ValueError("max_samples should be larger than min_samples.")

-        super().__init__(function, bounds, loss_per_interval)
+        super().__init__(function, bounds, loss_per_interval)  # type: ignore[arg-type]

         self.delta = delta
         self.alpha = alpha
@@ -110,7 +109,7 @@ def __init__(

         # Contains all samples f(x) for each
         # point x in the form {x0: {0: f_0(x0), 1: f_1(x0), ...}, ...}
-        self._data_samples = SortedDict()
+        self._data_samples: SortedDict[float, dict[int, Real]] = SortedDict()
         # Contains the number of samples taken
         # at each point x in the form {x0: n0, x1: n1, ...}
         self._number_samples = SortedDict()
@@ -124,14 +123,14 @@ def __init__(
         # form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
         self._distances: dict[Real, float] = decreasing_dict()
         # {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
-        self.rescaled_error: dict[Real, float] = decreasing_dict()
+        self.rescaled_error: ItemSortedDict[Real, float] = decreasing_dict()

     def new(self) -> AverageLearner1D:
         """Create a copy of `~adaptive.AverageLearner1D` without the data."""
         return AverageLearner1D(
             self.function,
             self.bounds,
-            self.loss_per_interval,
+            self.loss_per_interval,  # type: ignore[arg-type]
             self.delta,
             self.alpha,
             self.neighbor_sampling,
@@ -163,7 +162,7 @@ def to_numpy(self, mean: bool = False) -> np.ndarray:
                 ]
             )

-    def to_dataframe(
+    def to_dataframe(  # type: ignore[override]
         self,
         mean: bool = False,
         with_default_function_args: bool = True,
@@ -201,10 +200,10 @@ def to_dataframe(
         if not with_pandas:
             raise ImportError("pandas is not installed.")
         if mean:
-            data = sorted(self.data.items())
+            data: list[tuple[Real, Real]] = sorted(self.data.items())
             columns = [x_name, y_name]
         else:
-            data = [
+            data: list[tuple[int, Real, Real]] = [  # type: ignore[no-redef]
                 (seed, x, y)
                 for x, seed_y in sorted(self._data_samples.items())
                 for seed, y in sorted(seed_y.items())
@@ -217,7 +216,7 @@ def to_dataframe(
             assign_defaults(self.function, df, function_prefix)
         return df

-    def load_dataframe(
+    def load_dataframe(  # type: ignore[override]
         self,
         df: pandas.DataFrame,
         with_default_function_args: bool = True,
@@ -257,7 +256,7 @@ def load_dataframe(
                 self.function, df, function_prefix
             )

-    def ask(self, n: int, tell_pending: bool = True) -> tuple[Points, list[float]]:
+    def ask(self, n: int, tell_pending: bool = True) -> tuple[Points, list[float]]:  # type: ignore[override]
         """Return 'n' points that are expected to maximally reduce the loss."""
         # If some point is undersampled, resample it
         if len(self._undersampled_points):
@@ -310,18 +309,18 @@ def _ask_for_new_point(self, n: int) -> tuple[Points, list[float]]:
         new point, since in general n << min_samples and this point will need
         to be resampled many more times"""
         points, (loss_improvement,) = self._ask_points_without_adding(1)
-        points = [(seed, x) for seed, x in zip(range(n), n * points)]
+        seed_points = [(seed, x) for seed, x in zip(range(n), n * points)]
         loss_improvements = [loss_improvement / n] * n
-        return points, loss_improvements
+        return seed_points, loss_improvements  # type: ignore[return-value]

-    def tell_pending(self, seed_x: Point) -> None:
+    def tell_pending(self, seed_x: Point) -> None:  # type: ignore[override]
         _, x = seed_x
         self.pending_points.add(seed_x)
         if x not in self.data:
             self._update_neighbors(x, self.neighbors_combined)
             self._update_losses(x, real=False)

-    def tell(self, seed_x: Point, y: Real) -> None:
+    def tell(self, seed_x: Point, y: Real) -> None:  # type: ignore[override]
         seed, x = seed_x
         if y is None:
             raise TypeError(
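(Aside, not part of the diff: every method above works on (seed, x) pairs rather than bare x values. Below is a minimal usage sketch against the public adaptive API; the noisy function and its noise model are invented for illustration, and the exact points returned by ask depend on the learner's defaults.)

    import random

    from adaptive import AverageLearner1D

    def noisy(seed_x):
        # AverageLearner1D evaluates (seed, x) pairs so it can average
        # several noisy samples of the same x.
        seed, x = seed_x
        random.seed(seed)
        return x**2 + random.gauss(0, 0.1)

    learner = AverageLearner1D(noisy, bounds=(-1, 1))
    points, _ = learner.ask(3)   # e.g. [(0, x0), (1, x0), (2, x0)] for one new x0
    for p in points:
        learner.tell(p, noisy(p))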
@@ -492,7 +491,7 @@ def _calc_error_in_mean(self, ys: Iterable[Real], y_avg: Real, n: int) -> float:
         t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
         return t_student * (variance_in_mean / n) ** 0.5

-    def tell_many(
+    def tell_many(  # type: ignore[override]
         self, xs: Points | np.ndarray, ys: Sequence[Real] | np.ndarray
     ) -> None:
         # Check that all x are within the bounds
@@ -577,10 +576,10 @@ def tell_many_at_point(self, x: Real, seed_y_mapping: dict[int, Real]) -> None:
             self._update_interpolated_loss_in_interval(*interval)
         self._oldscale = deepcopy(self._scale)

-    def _get_data(self) -> dict[Real, dict[Int, Real]]:
+    def _get_data(self) -> dict[Real, dict[Int, Real]]:  # type: ignore[override]
         return self._data_samples

-    def _set_data(self, data: dict[Real, dict[Int, Real]]) -> None:
+    def _set_data(self, data: dict[Real, dict[Int, Real]]) -> None:  # type: ignore[override]
         if data:
             for x, samples in data.items():
                 self.tell_many_at_point(x, samples)
@@ -615,7 +614,7 @@ def plot(self):
         return p.redim(x={"range": plot_bounds})


-def decreasing_dict() -> dict:
+def decreasing_dict() -> ItemSortedDict:
     """This initialization orders the dictionary from large to small values"""

     def sorting_rule(key, value):
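(Aside, not part of the diff: the return-type change above reflects that decreasing_dict builds a sortedcollections.ItemSortedDict, whose iteration order follows a key function of each (key, value) pair, not a plain dict. A minimal sketch of that behaviour follows; the sorting_rule body is an assumption, since the diff is truncated right after its def line, and the sample values are invented.)

    from sortedcollections import ItemSortedDict
    from sortedcontainers import SortedDict

    def decreasing_dict() -> ItemSortedDict:
        def sorting_rule(key, value):
            return -value  # assumed body: larger values sort first

        return ItemSortedDict(sorting_rule, SortedDict())

    d = decreasing_dict()
    d[1.0] = 0.1
    d[2.0] = 0.5
    d[3.0] = 0.3
    print(list(d.keys()))  # [2.0, 3.0, 1.0] -- ordered by decreasing value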