3
3
import itertools
4
4
import sys
5
5
from collections import defaultdict
6
- from collections .abc import Iterable , Sequence
6
+ from collections .abc import Generator , Iterable , Sequence
7
7
from contextlib import suppress
8
8
from functools import partial
9
9
from operator import itemgetter
@@ -126,11 +126,10 @@ def __init__(
126
126
self ._cdims_default = cdims
127
127
128
128
if len ({learner .__class__ for learner in self .learners }) > 1 :
129
- raise TypeError (
130
- "A BalacingLearner can handle only one type" " of learners."
131
- )
129
+ raise TypeError ("A BalacingLearner can handle only one type of learners." )
132
130
133
131
self .strategy : STRATEGY_TYPE = strategy
132
+ self ._gen : Generator | None = None
134
133
135
134
def new (self ) -> BalancingLearner :
136
135
"""Create a new `BalancingLearner` with the same parameters."""
@@ -288,27 +287,16 @@ def _ask_and_tell_based_on_cycle(
288
287
def _ask_and_tell_based_on_sequential (
289
288
self , n : int
290
289
) -> tuple [list [tuple [Int , Any ]], list [float ]]:
290
+ if self ._gen is None :
291
+ self ._gen = _sequential_generator (self .learners )
291
292
points : list [tuple [Int , Any ]] = []
292
293
loss_improvements : list [float ] = []
293
- learner_index = 0
294
-
295
- while len (points ) < n :
296
- learner = self .learners [learner_index ]
297
- if learner .done (): # type: ignore[attr-defined]
298
- if learner_index == len (self .learners ) - 1 :
299
- break
300
- learner_index += 1
301
- continue
302
-
303
- point , loss_improvement = learner .ask (n = 1 )
304
- if not point : # if learner is exhausted, we don't get points
305
- if learner_index == len (self .learners ) - 1 :
306
- break
307
- learner_index += 1
308
- continue
309
- points .append ((learner_index , point [0 ]))
310
- loss_improvements .append (loss_improvement [0 ])
311
- self .tell_pending ((learner_index , point [0 ]))
294
+ for learner_index , point , loss_improvement in self ._gen :
295
+ points .append ((learner_index , point ))
296
+ loss_improvements .append (loss_improvement )
297
+ self .tell_pending ((learner_index , point ))
298
+ if len (points ) >= n :
299
+ break
312
300
313
301
return points , loss_improvements
314
302
@@ -629,3 +617,27 @@ def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
629
617
def __setstate__ (self , state : tuple [list [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]):
630
618
learners , cdims , strategy = state
631
619
self .__init__ (learners , cdims = cdims , strategy = strategy ) # type: ignore[misc]
620
+
621
+
622
+ def _sequential_generator (
623
+ learners : list [BaseLearner ],
624
+ ) -> Generator [tuple [int , Any , float ], None , None ]:
625
+ learner_index = 0
626
+ if not hasattr (learners [0 ], "done" ):
627
+ msg = "All learners must have a `done` method to use the 'sequential' strategy."
628
+ raise ValueError (msg )
629
+ while True :
630
+ learner = learners [learner_index ]
631
+ if learner .done (): # type: ignore[attr-defined]
632
+ if learner_index == len (learners ) - 1 :
633
+ return
634
+ learner_index += 1
635
+ continue
636
+
637
+ point , loss_improvement = learner .ask (n = 1 )
638
+ if not point : # if learner is exhausted, we don't get points
639
+ if learner_index == len (learners ) - 1 :
640
+ return
641
+ learner_index += 1
642
+ continue
643
+ yield learner_index , point [0 ], loss_improvement [0 ]
0 commit comments