Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit aee9e17

Browse filesBrowse files
committed
Implement using generator
1 parent 3523343 commit aee9e17
Copy full SHA for aee9e17

File tree

1 file changed

+35
-23
lines changed
Filter options

1 file changed

+35
-23
lines changed

‎adaptive/learner/balancing_learner.py

Copy file name to clipboardExpand all lines: adaptive/learner/balancing_learner.py
+35-23
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
import sys
55
from collections import defaultdict
6-
from collections.abc import Iterable, Sequence
6+
from collections.abc import Generator, Iterable, Sequence
77
from contextlib import suppress
88
from functools import partial
99
from operator import itemgetter
@@ -126,11 +126,10 @@ def __init__(
126126
self._cdims_default = cdims
127127

128128
if len({learner.__class__ for learner in self.learners}) > 1:
129-
raise TypeError(
130-
"A BalacingLearner can handle only one type" " of learners."
131-
)
129+
raise TypeError("A BalacingLearner can handle only one type of learners.")
132130

133131
self.strategy: STRATEGY_TYPE = strategy
132+
self._gen: Generator | None = None
134133

135134
def new(self) -> BalancingLearner:
136135
"""Create a new `BalancingLearner` with the same parameters."""
@@ -288,27 +287,16 @@ def _ask_and_tell_based_on_cycle(
288287
def _ask_and_tell_based_on_sequential(
289288
self, n: int
290289
) -> tuple[list[tuple[Int, Any]], list[float]]:
290+
if self._gen is None:
291+
self._gen = _sequential_generator(self.learners)
291292
points: list[tuple[Int, Any]] = []
292293
loss_improvements: list[float] = []
293-
learner_index = 0
294-
295-
while len(points) < n:
296-
learner = self.learners[learner_index]
297-
if learner.done(): # type: ignore[attr-defined]
298-
if learner_index == len(self.learners) - 1:
299-
break
300-
learner_index += 1
301-
continue
302-
303-
point, loss_improvement = learner.ask(n=1)
304-
if not point: # if learner is exhausted, we don't get points
305-
if learner_index == len(self.learners) - 1:
306-
break
307-
learner_index += 1
308-
continue
309-
points.append((learner_index, point[0]))
310-
loss_improvements.append(loss_improvement[0])
311-
self.tell_pending((learner_index, point[0]))
294+
for learner_index, point, loss_improvement in self._gen:
295+
points.append((learner_index, point))
296+
loss_improvements.append(loss_improvement)
297+
self.tell_pending((learner_index, point))
298+
if len(points) >= n:
299+
break
312300

313301
return points, loss_improvements
314302

@@ -629,3 +617,27 @@ def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
629617
def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
630618
learners, cdims, strategy = state
631619
self.__init__(learners, cdims=cdims, strategy=strategy) # type: ignore[misc]
620+
621+
622+
def _sequential_generator(
623+
learners: list[BaseLearner],
624+
) -> Generator[tuple[int, Any, float], None, None]:
625+
learner_index = 0
626+
if not hasattr(learners[0], "done"):
627+
msg = "All learners must have a `done` method to use the 'sequential' strategy."
628+
raise ValueError(msg)
629+
while True:
630+
learner = learners[learner_index]
631+
if learner.done(): # type: ignore[attr-defined]
632+
if learner_index == len(learners) - 1:
633+
return
634+
learner_index += 1
635+
continue
636+
637+
point, loss_improvement = learner.ask(n=1)
638+
if not point: # if learner is exhausted, we don't get points
639+
if learner_index == len(learners) - 1:
640+
return
641+
learner_index += 1
642+
continue
643+
yield learner_index, point[0], loss_improvement[0]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.