Update hyperopt.py

This commit is contained in:
Italo 2022-03-20 16:12:06 +00:00
parent 112738d68d
commit fca93d8dfe

View File

@ -414,30 +414,33 @@ class Hyperopt:
# Store non-trimmed data - will be trimmed after signal generation. # Store non-trimmed data - will be trimmed after signal generation.
dump(preprocessed, self.data_pickle_file) dump(preprocessed, self.data_pickle_file)
def get_asked_points(self, n_points: int) -> List[Any]: def get_asked_points(self, n_points: int) -> List[List[Any]]:
''' '''
Enforce points returned from `self.opt.ask` have not been already evaluated
Steps: Steps:
1. Try to get points using `self.opt.ask` first 1. Try to get points using `self.opt.ask` first
2. Discard the points that have already been evaluated 2. Discard the points that have already been evaluated
3. Retry using `self.opt.ask` up to 3 times 3. Retry using `self.opt.ask` up to 3 times
4. If still some points are missing in respect to `n_points`, random sample some points 4. If still some points are missing in respect to `n_points`, random sample some points
5. Repeat until at least `n_points` points in the `asked_non_tried` list 5. Repeat until at least `n_points` points in the `asked_non_tried` list
6. Return a list with legth truncated at `n_points` 6. Return a list with length truncated at `n_points`
''' '''
i = 0 i = 0
asked_non_tried = [] asked_non_tried: List[List[Any]] = []
while i < 100: while i < 100 and len(asked_non_tried) < n_points:
if len(asked_non_tried) < n_points:
if i < 3: if i < 3:
asked = self.opt.ask(n_points=n_points) asked = self.opt.ask(n_points=n_points)
else: else:
# use random sample if `self.opt.ask` returns points points already tried
asked = self.opt.space.rvs(n_samples=n_points * 5) asked = self.opt.space.rvs(n_samples=n_points * 5)
asked_non_tried += [x for x in asked if x not in self.opt.Xi and x not in asked_non_tried] asked_non_tried += [x for x in asked
if x not in self.opt.Xi
and x not in asked_non_tried]
i += 1 i += 1
else: if asked_non_tried:
break
return asked_non_tried[:n_points] return asked_non_tried[:n_points]
else:
return self.opt.ask(n_points=n_points)
def start(self) -> None: def start(self) -> None:
self.random_state = self._set_random_state(self.config.get('hyperopt_random_state', None)) self.random_state = self._set_random_state(self.config.get('hyperopt_random_state', None))