better fit logic (and multi-opt was never fitting -_-)
This commit is contained in:
parent
027dae1c9b
commit
ef6efb7117
@ -605,7 +605,7 @@ class Hyperopt:
|
|||||||
i, backend.optimizers, jobs, backend.results_board)
|
i, backend.optimizers, jobs, backend.results_board)
|
||||||
for i in range(first_try, first_try + tries))
|
for i in range(first_try, first_try + tries))
|
||||||
# each worker will return a list containing n_points, so compact into a single list
|
# each worker will return a list containing n_points, so compact into a single list
|
||||||
return functools.reduce(lambda x, y: [*x, *y], results)
|
return functools.reduce(lambda x, y: [*x, *y], results, [])
|
||||||
|
|
||||||
def opt_ask_and_tell(self, jobs: int, tries: int):
|
def opt_ask_and_tell(self, jobs: int, tries: int):
|
||||||
"""
|
"""
|
||||||
@ -617,7 +617,6 @@ class Hyperopt:
|
|||||||
vals = []
|
vals = []
|
||||||
to_ask: deque = deque()
|
to_ask: deque = deque()
|
||||||
evald: Set[Tuple] = set()
|
evald: Set[Tuple] = set()
|
||||||
fit = False
|
|
||||||
opt = self.opt
|
opt = self.opt
|
||||||
for r in range(tries):
|
for r in range(tries):
|
||||||
while not backend.results.empty():
|
while not backend.results.empty():
|
||||||
@ -628,15 +627,11 @@ class Hyperopt:
|
|||||||
if void_filtered: # again if all are filtered
|
if void_filtered: # again if all are filtered
|
||||||
opt.tell([list(v['params_dict'].values()) for v in void_filtered],
|
opt.tell([list(v['params_dict'].values()) for v in void_filtered],
|
||||||
[v['loss'] for v in void_filtered],
|
[v['loss'] for v in void_filtered],
|
||||||
fit=fit)
|
fit=(len(to_ask) < 1)) # only fit when out of points
|
||||||
if fit:
|
|
||||||
fit = False
|
|
||||||
del vals[:], void_filtered[:]
|
del vals[:], void_filtered[:]
|
||||||
|
|
||||||
if not to_ask:
|
if not to_ask:
|
||||||
opt.update_next()
|
|
||||||
to_ask.extend(opt.ask(n_points=self.n_points))
|
to_ask.extend(opt.ask(n_points=self.n_points))
|
||||||
fit = True
|
|
||||||
a = tuple(to_ask.popleft())
|
a = tuple(to_ask.popleft())
|
||||||
while a in evald:
|
while a in evald:
|
||||||
logger.info("this point was evaluated before...")
|
logger.info("this point was evaluated before...")
|
||||||
@ -675,8 +670,7 @@ class Hyperopt:
|
|||||||
# put back the updated results
|
# put back the updated results
|
||||||
results_board.put(results)
|
results_board.put(results)
|
||||||
if len(past_Xi) > 0:
|
if len(past_Xi) > 0:
|
||||||
opt.tell(past_Xi, past_yi, fit=False)
|
opt.tell(past_Xi, past_yi, fit=True)
|
||||||
opt.update_next()
|
|
||||||
|
|
||||||
# ask for points according to config
|
# ask for points according to config
|
||||||
asked = opt.ask(n_points=self.n_points, strategy=self.get_next_point_strategy())
|
asked = opt.ask(n_points=self.n_points, strategy=self.get_next_point_strategy())
|
||||||
@ -688,7 +682,9 @@ class Hyperopt:
|
|||||||
if opt.void_loss != VOID_LOSS or len(void_filtered) > 0:
|
if opt.void_loss != VOID_LOSS or len(void_filtered) > 0:
|
||||||
Xi = [list(v['params_dict'].values()) for v in void_filtered]
|
Xi = [list(v['params_dict'].values()) for v in void_filtered]
|
||||||
yi = [v['loss'] for v in void_filtered]
|
yi = [v['loss'] for v in void_filtered]
|
||||||
opt.tell(Xi, yi, fit=False)
|
# because we fit with points from other runs
|
||||||
|
# only fit if at the current dispatch there were no points
|
||||||
|
opt.tell(Xi, yi, fit=(len(past_Xi) < 1))
|
||||||
# update the board with the new results
|
# update the board with the new results
|
||||||
results = results_board.get()
|
results = results_board.get()
|
||||||
results.append([void_filtered, jobs - 1])
|
results.append([void_filtered, jobs - 1])
|
||||||
@ -956,6 +952,9 @@ class Hyperopt:
|
|||||||
f_val = jobs_scheduler(parallel, batch_len, epochs_so_far, self.n_jobs)
|
f_val = jobs_scheduler(parallel, batch_len, epochs_so_far, self.n_jobs)
|
||||||
saved = self.log_results(f_val, epochs_so_far, epochs_limit())
|
saved = self.log_results(f_val, epochs_so_far, epochs_limit())
|
||||||
# stop if no epochs have been evaluated
|
# stop if no epochs have been evaluated
|
||||||
|
if len(f_val) < 1:
|
||||||
|
logger.warning("All epochs evaluated were void, "
|
||||||
|
"check the loss function and the search space.")
|
||||||
if (not saved and len(f_val) > 1) or batch_len < 1:
|
if (not saved and len(f_val) > 1) or batch_len < 1:
|
||||||
break
|
break
|
||||||
# log_results add
|
# log_results add
|
||||||
|
Loading…
Reference in New Issue
Block a user