better fit logic (and multi-opt was never fitting -_-)

This commit is contained in:
orehunt 2020-03-13 18:57:05 +01:00
parent 027dae1c9b
commit ef6efb7117

View File

@ -605,7 +605,7 @@ class Hyperopt:
i, backend.optimizers, jobs, backend.results_board)
for i in range(first_try, first_try + tries))
# each worker will return a list containing n_points, so compact into a single list
return functools.reduce(lambda x, y: [*x, *y], results)
return functools.reduce(lambda x, y: [*x, *y], results, [])
def opt_ask_and_tell(self, jobs: int, tries: int):
"""
@ -617,7 +617,6 @@ class Hyperopt:
vals = []
to_ask: deque = deque()
evald: Set[Tuple] = set()
fit = False
opt = self.opt
for r in range(tries):
while not backend.results.empty():
@ -628,15 +627,11 @@ class Hyperopt:
if void_filtered: # again if all are filtered
opt.tell([list(v['params_dict'].values()) for v in void_filtered],
[v['loss'] for v in void_filtered],
fit=fit)
if fit:
fit = False
fit=(len(to_ask) < 1)) # only fit when out of points
del vals[:], void_filtered[:]
if not to_ask:
opt.update_next()
to_ask.extend(opt.ask(n_points=self.n_points))
fit = True
a = tuple(to_ask.popleft())
while a in evald:
logger.info("this point was evaluated before...")
@ -675,8 +670,7 @@ class Hyperopt:
# put back the updated results
results_board.put(results)
if len(past_Xi) > 0:
opt.tell(past_Xi, past_yi, fit=False)
opt.update_next()
opt.tell(past_Xi, past_yi, fit=True)
# ask for points according to config
asked = opt.ask(n_points=self.n_points, strategy=self.get_next_point_strategy())
@ -688,7 +682,9 @@ class Hyperopt:
if opt.void_loss != VOID_LOSS or len(void_filtered) > 0:
Xi = [list(v['params_dict'].values()) for v in void_filtered]
yi = [v['loss'] for v in void_filtered]
opt.tell(Xi, yi, fit=False)
# because we fit with points from other runs
# only fit if at the current dispatch there were no points
opt.tell(Xi, yi, fit=(len(past_Xi) < 1))
# update the board with the new results
results = results_board.get()
results.append([void_filtered, jobs - 1])
@ -956,6 +952,9 @@ class Hyperopt:
f_val = jobs_scheduler(parallel, batch_len, epochs_so_far, self.n_jobs)
saved = self.log_results(f_val, epochs_so_far, epochs_limit())
# stop if no epochs have been evaluated
if len(f_val) < 1:
logger.warning("All epochs evaluated were void, "
"check the loss function and the search space.")
if (not saved and len(f_val) > 1) or batch_len < 1:
break
# log_results add