more fixes for epochs counting

orehunt 2020-03-12 12:55:00 +01:00
parent ece0ddba38
commit 027dae1c9b


@@ -716,9 +716,9 @@ class Hyperopt:
         print()
         current = frame_start + 1
         i = 0
-        for i, v in enumerate(f_val):
+        for i, v in enumerate(f_val, 1):
             is_best = self.is_best_loss(v, self.current_best_loss)
-            current = frame_start + i + 1
+            current = frame_start + i
             v['is_best'] = is_best
             v['current_epoch'] = current
             v['is_initial_point'] = current <= self.n_initial_points
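Note on the hunk above: starting `enumerate` at 1 leaves the computed epoch number unchanged (`frame_start + i` now equals the old `frame_start + i + 1`) while making `i` equal to the number of evaluated epochs once the loop ends, which the later `return i` presumably relies on. A minimal standalone sketch of the equivalence:

```python
# Standalone sketch (not repository code): the two changes cancel out for
# the epoch number but fix the post-loop value of `i`.
f_val = [{'loss': 0.5}, {'loss': 0.3}, {'loss': 0.7}]  # hypothetical batch results
frame_start = 10

old_epochs = [frame_start + i + 1 for i, _ in enumerate(f_val)]
new_epochs = [frame_start + i for i, _ in enumerate(f_val, 1)]
assert old_epochs == new_epochs == [11, 12, 13]
# with enumerate(f_val, 1) the loop variable ends at len(f_val) == 3,
# instead of len(f_val) - 1
```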
@@ -735,6 +735,20 @@ class Hyperopt:
             self.max_epoch_reached = True
         return i
 
+    def setup_best_epochs(self) -> bool:
+        """ used to resume the best epochs state from previous trials """
+        len_trials = len(self.trials)
+        if len_trials > 0:
+            best_epochs = list(filter(lambda k: k["is_best"], self.trials))
+            len_best = len(best_epochs)
+            if len_best > 0:
+                # sorting from lowest to highest, the first value is the current best
+                best = sorted(best_epochs, key=lambda k: k["loss"])[0]
+                self.current_best_epoch = best["current_epoch"]
+                self.avg_best_occurrence = len_trials // len_best
+                return True
+        return False
+
     @staticmethod
     def load_previous_results(trials_file: Path) -> List:
         """
@@ -790,20 +804,21 @@ class Hyperopt:
             (factorial(n_parameters) /
              (factorial(n_parameters - n_dimensions) * factorial(n_dimensions))))
         # logger.info(f'Search space size: {search_space_size}')
+        log_opt = int(log(opt_points, 2)) if opt_points > 4 else 2
         if search_space_size < opt_points:
             # don't waste if the space is small
             n_initial_points = opt_points // 3
             min_epochs = opt_points
         elif total_epochs > 0:
-            n_initial_points = total_epochs // 3 if total_epochs > opt_points * 3 else opt_points
+            # coefficients from total epochs
+            log_epp = int(log(total_epochs, 2)) * log_opt
+            n_initial_points = min(log_epp, total_epochs // 3)
             min_epochs = total_epochs
         else:
-            # extract coefficients from the search space and the jobs count
-            log_sss = int(log(search_space_size, 10))
-            log_opt = int(log(opt_points, 2)) if opt_points > 4 else 2
-            opt_ip = log_opt * log_sss
+            # extract coefficients from the search space
+            log_sss = int(log(search_space_size, 10)) * log_opt
             # never waste
-            n_initial_points = log_sss if opt_ip > search_space_size else opt_ip
+            n_initial_points = min(log_sss, search_space_size // 3)
             # it shall run for this much, I say
             min_epochs = int(max(n_initial_points, opt_points) * (1 + effort) + n_initial_points)
         return n_initial_points, min_epochs, search_space_size
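A worked trace of the reworked heuristic with assumed inputs (`opt_points=100`, `total_epochs=800`, and a search space larger than `opt_points`), following only the `elif total_epochs > 0` branch:

```python
from math import log

# Assumed inputs, purely illustrative.
opt_points, total_epochs = 100, 800

log_opt = int(log(opt_points, 2)) if opt_points > 4 else 2  # int(6.64) == 6
log_epp = int(log(total_epochs, 2)) * log_opt               # int(9.64) * 6 == 54
n_initial_points = min(log_epp, total_epochs // 3)          # min(54, 266) == 54
min_epochs = total_epochs                                    # 800
assert (n_initial_points, min_epochs) == (54, 800)
```

In both log-based branches the initial random points now grow logarithmically with the budget (or the search space size) instead of taking a flat third, so large runs no longer spend a third of their epochs before the optimizer starts exploiting.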
@@ -899,6 +914,7 @@ class Hyperopt:
 
         self.backtesting.exchange = None  # type: ignore
         self.trials = self.load_previous_results(self.trials_file)
+        self.setup_best_epochs()
 
         logger.info(f"Found {cpu_count()} CPU cores. Let's make them scream!")
         logger.info(f'Number of parallel jobs set as: {self.n_jobs}')
@@ -918,9 +934,11 @@ class Hyperopt:
         with parallel_backend('loky', inner_max_num_threads=2):
             with Parallel(n_jobs=self.n_jobs, verbose=0, backend='loky') as parallel:
                 # update epochs count
+                n_points = self.n_points
                 prev_batch = -1
                 epochs_so_far = len(self.trials)
-                while prev_batch < epochs_so_far:
+                epochs_limit = self.epochs_limit
+                while epochs_so_far > prev_batch or epochs_so_far < self.min_epochs:
                     prev_batch = epochs_so_far
                     # pad the batch length to the number of jobs to avoid desaturation
                     batch_len = (self.avg_best_occurrence + self.n_jobs -
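The padding expression above is cut off at the hunk boundary, so only its intent is visible: keep every worker busy by rounding the batch up to a multiple of the job count. A hypothetical helper showing that arithmetic (the name and shape are assumptions, not the diff's code):

```python
def pad_to_jobs(batch_len: int, n_jobs: int) -> int:
    # hypothetical helper: round batch_len up to the next multiple of n_jobs
    # so a dispatch never leaves workers idle ("desaturation")
    return batch_len + (-batch_len % n_jobs)

assert pad_to_jobs(10, 4) == 12  # 2 padding epochs keep all 4 jobs busy
assert pad_to_jobs(8, 4) == 8    # already a multiple, no padding
```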
@@ -929,16 +947,16 @@ class Hyperopt:
                     # n_points (epochs) in 1 dispatch but this reduces the batch len too much
                     # if self.multi: batch_len = batch_len // self.n_points
                     # don't go over the limit
-                    if epochs_so_far + batch_len > self.epochs_limit():
-                        batch_len = self.epochs_limit() - epochs_so_far
+                    if epochs_so_far + batch_len * n_points > epochs_limit():
+                        batch_len = (epochs_limit() - epochs_so_far) // n_points
                     print(
-                        f"{epochs_so_far+1}-{epochs_so_far+batch_len}"
-                        f"/{self.epochs_limit()}: ",
+                        f"{epochs_so_far+1}-{epochs_so_far+batch_len*n_points}"
+                        f"/{epochs_limit()}: ",
                         end='')
                     f_val = jobs_scheduler(parallel, batch_len, epochs_so_far, self.n_jobs)
-                    saved = self.log_results(f_val, epochs_so_far, self.epochs_limit())
+                    saved = self.log_results(f_val, epochs_so_far, epochs_limit())
                     # stop if no epochs have been evaluated
-                    if not saved or batch_len < 1:
+                    if (not saved and len(f_val) > 1) or batch_len < 1:
                         break
                     # log_results add
                     epochs_so_far += saved
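With `n_points` epochs evaluated per scheduled point, both the limit check and the printed range now scale by `n_points`, and the loop condition keeps running while under `self.min_epochs` even when a batch makes no progress. A small sketch of the clamping arithmetic with assumed numbers:

```python
# Assumed numbers, illustrating the clamp only.
n_points, limit = 5, 100
epochs_so_far, batch_len = 90, 8

if epochs_so_far + batch_len * n_points > limit:      # 90 + 40 > 100
    batch_len = (limit - epochs_so_far) // n_points   # 10 // 5 == 2
assert batch_len == 2
assert epochs_so_far + batch_len * n_points <= limit  # exactly 100
```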