fixes, moved points setup to its own function

orehunt 2020-03-10 09:10:10 +01:00
parent 29e9faf167
commit ece0ddba38


@@ -15,7 +15,7 @@ from numpy import iinfo, int32
from operator import itemgetter
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple, Set
import rapidjson
from colorama import Fore, Style
@@ -55,7 +55,7 @@ logger = logging.getLogger(__name__)
NEXT_POINT_METHODS = ["cl_min", "cl_mean", "cl_max"]
NEXT_POINT_METHODS_LENGTH = 3
VOID_LOSS = iinfo(int32).max # just a big enough number to be bad result in loss optimization
VOID_LOSS = iinfo(int32).max # just a big enough number to be bad result in loss optimization
class Hyperopt:
@@ -555,7 +555,7 @@ class Hyperopt:
# only exclude results at the beginning when void loss is yet to be set
void_filtered = list(filter(lambda v: v["loss"] != VOID_LOSS, vals))
else:
if opt.void_loss == VOID_LOSS: # set void loss once
if opt.void_loss == VOID_LOSS: # set void loss once
opt.void_loss = max(opt.yi)
void_filtered = []
# default bad losses to set void_loss
@@ -616,7 +616,7 @@ class Hyperopt:
"""
vals = []
to_ask: deque = deque()
evald: set(Tuple) = set()
evald: Set[Tuple] = set()
fit = False
opt = self.opt
for r in range(tries):
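A quick aside on the evald annotation above, which is what the new Tuple and Set typing imports are for: set(Tuple) is a call expression rather than a type (a local annotation is never evaluated at runtime, but a type checker rejects it), while Set[Tuple] is the intended generic annotation. A minimal standalone sketch:

from typing import Set, Tuple

# Set[Tuple] is the generic "set of tuples" annotation; set(Tuple) would be
# a function call, not a type, so tools like mypy flag it.
evald: Set[Tuple] = set()
evald.add((30, -0.1))          # hypothetical parameter tuple
print((30, -0.1) in evald)     # True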
@@ -625,10 +625,10 @@
if vals:
# filter losses
void_filtered = self.filter_void_losses(vals, opt)
if vals: # again if all are filtered
if void_filtered: # again if all are filtered
opt.tell([list(v['params_dict'].values()) for v in void_filtered],
[v['loss'] for v in vals],
fit=fit)
[v['loss'] for v in void_filtered],
fit=fit)
if fit:
fit = False
del vals[:], void_filtered[:]
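The hunk above fixes two related problems: the re-check now tests void_filtered instead of vals, and the losses handed to opt.tell come from the same filtered list as the parameters, so the two sequences stay index-aligned. A minimal sketch with hypothetical trial dicts (not the real Hyperopt results):

from numpy import iinfo, int32

VOID_LOSS = iinfo(int32).max
vals = [
    {"params_dict": {"rsi": 30, "stoploss": -0.1}, "loss": 1.2},
    {"params_dict": {"rsi": 70, "stoploss": -0.2}, "loss": VOID_LOSS},  # gets filtered out
]
void_filtered = [v for v in vals if v["loss"] != VOID_LOSS]

Xi = [list(v["params_dict"].values()) for v in void_filtered]
yi = [v["loss"] for v in void_filtered]   # taking losses from vals instead would desync Xi and yi
assert len(Xi) == len(yi)                 # 1 == 1; against vals it would be 1 vs 2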
@@ -841,7 +841,7 @@ class Hyperopt:
for _ in range(remaining): # generate optimizers
# random state is preserved
opt_copy = opt.copy(random_state=opt.rng.randint(0,
iinfo(int32).max))
iinfo(int32).max))
opt_copy.void_loss = VOID_LOSS
backend.optimizers.put(opt_copy)
del opt, opt_copy
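For context, the re-indented call above copies what looks like a scikit-optimize Optimizer (matching the opt.tell and opt.copy usage in this diff), seeding every copy from the parent's rng so each worker explores independently while the parent state is preserved. A rough sketch under that assumption, with a toy one-dimensional space:

from numpy import iinfo, int32
from skopt import Optimizer

VOID_LOSS = iinfo(int32).max
opt = Optimizer([(-1.0, 1.0)], n_initial_points=4, random_state=42)
# each copy gets its own seed drawn from the parent's random state
opt_copy = opt.copy(random_state=opt.rng.randint(0, iinfo(int32).max))
opt_copy.void_loss = VOID_LOSS   # plain attribute used as the "not set yet" sentinel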
@@ -857,6 +857,23 @@ class Hyperopt:
self.opt.void_loss = VOID_LOSS
del opts[:]
def setup_points(self):
self.n_initial_points, self.min_epochs, self.search_space_size = self.calc_epochs(
self.dimensions, self.n_jobs, self.effort, self.total_epochs, self.n_points
)
logger.info(f"Min epochs set to: {self.min_epochs}")
# reduce random points by n_points in multi mode because asks are per job
if self.multi:
self.opt_n_initial_points = self.n_initial_points // self.n_points
else:
self.opt_n_initial_points = self.n_initial_points
logger.info(f'Initial points: {self.n_initial_points}')
# if total epochs are not set, max_epoch takes its place
if self.total_epochs < 1:
self.max_epoch = int(self.min_epochs + len(self.trials))
# initialize average best occurrence
self.avg_best_occurrence = self.min_epochs // self.n_jobs
def start(self) -> None:
""" Broom Broom """
self.random_state = self._set_random_state(self.config.get('hyperopt_random_state', None))
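A rough worked example of the integer divisions in setup_points, using hypothetical numbers (the real ones come from calc_epochs and the config):

n_initial_points, n_points, n_jobs, min_epochs = 64, 4, 8, 240
multi = True

# in multi mode every ask yields n_points candidates per job, so fewer random asks are needed
opt_n_initial_points = n_initial_points // n_points if multi else n_initial_points   # 16
avg_best_occurrence = min_epochs // n_jobs                                            # 30
print(opt_n_initial_points, avg_best_occurrence)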
@@ -887,22 +904,8 @@ class Hyperopt:
logger.info(f'Number of parallel jobs set as: {self.n_jobs}')
self.dimensions: List[Dimension] = self.hyperopt_space()
self.n_initial_points, self.min_epochs, self.search_space_size = self.calc_epochs(
self.dimensions, self.n_jobs, self.effort, self.total_epochs, self.n_points
)
# reduce random points by n_points in multi mode because asks are per job
if self.multi:
self.opt_n_initial_points = self.n_initial_points // self.n_points
else:
self.opt_n_initial_points = self.n_initial_points
logger.info(f"Min epochs set to: {self.min_epochs}")
# if total epochs are not set, max_epoch takes its place
if self.total_epochs < 1:
self.max_epoch = int(self.min_epochs + len(self.trials))
# initialize average best occurrence
self.avg_best_occurrence = self.min_epochs // self.n_jobs
self.setup_points()
logger.info(f'Initial points: {self.n_initial_points}')
if self.print_colorized:
colorama_init(autoreset=True)
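One detail carried over from the old inline block into setup_points: when total_epochs is not set (< 1), max_epoch falls back to min_epochs plus the number of trials already loaded. Worked through with hypothetical numbers:

total_epochs = 0
min_epochs = 240
trials = [{} for _ in range(12)]   # stand-in for previously saved trials
max_epoch = int(min_epochs + len(trials)) if total_epochs < 1 else total_epochs
print(max_epoch)   # 252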