Hyperopt cleanup, do not use 'trials'

This commit is contained in:
hroff-1902 2020-04-28 22:56:19 +03:00
parent 09e488a693
commit c26835048c
2 changed files with 59 additions and 54 deletions

View File

@ -75,8 +75,8 @@ class Hyperopt:
self.custom_hyperoptloss = HyperOptLossResolver.load_hyperoptloss(self.config) self.custom_hyperoptloss = HyperOptLossResolver.load_hyperoptloss(self.config)
self.calculate_loss = self.custom_hyperoptloss.hyperopt_loss_function self.calculate_loss = self.custom_hyperoptloss.hyperopt_loss_function
self.trials_file = (self.config['user_data_dir'] / self.results_file = (self.config['user_data_dir'] /
'hyperopt_results' / 'hyperopt_results.pickle') 'hyperopt_results' / 'hyperopt_results.pickle')
self.data_pickle_file = (self.config['user_data_dir'] / self.data_pickle_file = (self.config['user_data_dir'] /
'hyperopt_results' / 'hyperopt_tickerdata.pkl') 'hyperopt_results' / 'hyperopt_tickerdata.pkl')
self.total_epochs = config.get('epochs', 0) self.total_epochs = config.get('epochs', 0)
@ -88,10 +88,10 @@ class Hyperopt:
else: else:
logger.info("Continuing on previous hyperopt results.") logger.info("Continuing on previous hyperopt results.")
self.num_trials_saved = 0 self.num_epochs_saved = 0
# Previous evaluations # Previous evaluations
self.trials: List = [] self.epochs: List = []
# Populate functions here (hasattr is slow so should not be run during "regular" operations) # Populate functions here (hasattr is slow so should not be run during "regular" operations)
if hasattr(self.custom_hyperopt, 'populate_indicators'): if hasattr(self.custom_hyperopt, 'populate_indicators'):
@ -132,7 +132,7 @@ class Hyperopt:
""" """
Remove hyperopt pickle files to restart hyperopt. Remove hyperopt pickle files to restart hyperopt.
""" """
for f in [self.data_pickle_file, self.trials_file]: for f in [self.data_pickle_file, self.results_file]:
p = Path(f) p = Path(f)
if p.is_file(): if p.is_file():
logger.info(f"Removing `{p}`.") logger.info(f"Removing `{p}`.")
@ -151,27 +151,26 @@ class Hyperopt:
# and the values are taken from the list of parameters. # and the values are taken from the list of parameters.
return {d.name: v for d, v in zip(dimensions, raw_params)} return {d.name: v for d, v in zip(dimensions, raw_params)}
def save_trials(self, final: bool = False) -> None: def _save_results(self) -> None:
""" """
Save hyperopt trials to file Save hyperopt results to file
""" """
num_trials = len(self.trials) num_epochs = len(self.epochs)
if num_trials > self.num_trials_saved: if num_epochs > self.num_epochs_saved:
logger.debug(f"Saving {num_trials} {plural(num_trials, 'epoch')}.") logger.debug(f"Saving {num_epochs} {plural(num_epochs, 'epoch')}.")
dump(self.trials, self.trials_file) dump(self.epochs, self.results_file)
self.num_trials_saved = num_trials self.num_epochs_saved = num_epochs
if final: logger.debug(f"{self.num_epochs_saved} {plural(self.num_epochs_saved, 'epoch')} "
logger.info(f"{num_trials} {plural(num_trials, 'epoch')} " f"saved to '{self.results_file}'.")
f"saved to '{self.trials_file}'.")
@staticmethod @staticmethod
def _read_trials(trials_file: Path) -> List: def _read_results(results_file: Path) -> List:
""" """
Read hyperopt trials file Read hyperopt results from file
""" """
logger.info("Reading Trials from '%s'", trials_file) logger.info("Reading epochs from '%s'", results_file)
trials = load(trials_file) data = load(results_file)
return trials return data
def _get_params_details(self, params: Dict) -> Dict: def _get_params_details(self, params: Dict) -> Dict:
""" """
@ -588,19 +587,20 @@ class Hyperopt:
wrap_non_picklable_objects(self.generate_optimizer))(v, i) for v in asked) wrap_non_picklable_objects(self.generate_optimizer))(v, i) for v in asked)
@staticmethod @staticmethod
def load_previous_results(trials_file: Path) -> List: def load_previous_results(results_file: Path) -> List:
""" """
Load data for epochs from the file if we have one Load data for epochs from the file if we have one
""" """
trials: List = [] epochs: List = []
if trials_file.is_file() and trials_file.stat().st_size > 0: if results_file.is_file() and results_file.stat().st_size > 0:
trials = Hyperopt._read_trials(trials_file) epochs = Hyperopt._read_results(results_file)
if trials[0].get('is_best') is None: # Detection of some old format, without 'is_best' field saved
if epochs[0].get('is_best') is None:
raise OperationalException( raise OperationalException(
"The file with Hyperopt results is incompatible with this version " "The file with Hyperopt results is incompatible with this version "
"of Freqtrade and cannot be loaded.") "of Freqtrade and cannot be loaded.")
logger.info(f"Loaded {len(trials)} previous evaluations from disk.") logger.info(f"Loaded {len(epochs)} previous evaluations from disk.")
return trials return epochs
def _set_random_state(self, random_state: Optional[int]) -> int: def _set_random_state(self, random_state: Optional[int]) -> int:
return random_state or random.randint(1, 2**16 - 1) return random_state or random.randint(1, 2**16 - 1)
@ -628,7 +628,7 @@ class Hyperopt:
self.backtesting.exchange = None # type: ignore self.backtesting.exchange = None # type: ignore
self.backtesting.pairlists = None # type: ignore self.backtesting.pairlists = None # type: ignore
self.trials = self.load_previous_results(self.trials_file) self.epochs = self.load_previous_results(self.results_file)
cpus = cpu_count() cpus = cpu_count()
logger.info(f"Found {cpus} CPU cores. Let's make them scream!") logger.info(f"Found {cpus} CPU cores. Let's make them scream!")
@ -698,23 +698,25 @@ class Hyperopt:
if is_best: if is_best:
self.current_best_loss = val['loss'] self.current_best_loss = val['loss']
self.trials.append(val) self.epochs.append(val)
# Save results after each best epoch and every 100 epochs # Save results after each best epoch and every 100 epochs
if is_best or current % 100 == 0: if is_best or current % 100 == 0:
self.save_trials() self._save_results()
pbar.update(current) pbar.update(current)
except KeyboardInterrupt: except KeyboardInterrupt:
print('User interrupted..') print('User interrupted..')
self.save_trials(final=True) self._save_results()
logger.info(f"{self.num_epochs_saved} {plural(self.num_epochs_saved, 'epoch')} "
f"saved to '{self.results_file}'.")
if self.trials: if self.epochs:
sorted_trials = sorted(self.trials, key=itemgetter('loss')) sorted_epochs = sorted(self.epochs, key=itemgetter('loss'))
results = sorted_trials[0] best_epoch = sorted_epochs[0]
self.print_epoch_details(results, self.total_epochs, self.print_json) self.print_epoch_details(best_epoch, self.total_epochs, self.print_json)
else: else:
# This is printed when Ctrl+C is pressed quickly, before first epochs have # This is printed when Ctrl+C is pressed quickly, before first epochs have
# a chance to be evaluated. # a chance to be evaluated.

View File

@ -1,5 +1,6 @@
# pragma pylint: disable=missing-docstring,W0212,C0103 # pragma pylint: disable=missing-docstring,W0212,C0103
import locale import locale
import logging
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict, List
@ -56,14 +57,14 @@ def hyperopt_results():
# Functions for recurrent object patching # Functions for recurrent object patching
def create_trials(mocker, hyperopt, testdatadir) -> List[Dict]: def create_results(mocker, hyperopt, testdatadir) -> List[Dict]:
""" """
When creating trials, mock the hyperopt Trials so that *by default* When creating results, mock the hyperopt so that *by default*
- we don't create any pickle'd files in the filesystem - we don't create any pickle'd files in the filesystem
- we might have a pickle'd file so make sure that we return - we might have a pickle'd file so make sure that we return
false when looking for it false when looking for it
""" """
hyperopt.trials_file = testdatadir / 'optimize/ut_trials.pickle' hyperopt.results_file = testdatadir / 'optimize/ut_results.pickle'
mocker.patch.object(Path, "is_file", MagicMock(return_value=False)) mocker.patch.object(Path, "is_file", MagicMock(return_value=False))
stat_mock = MagicMock() stat_mock = MagicMock()
@ -477,28 +478,30 @@ def test_no_log_if_loss_does_not_improve(hyperopt, caplog) -> None:
assert caplog.record_tuples == [] assert caplog.record_tuples == []
def test_save_trials_saves_trials(mocker, hyperopt, testdatadir, caplog) -> None: def test_save_results_saves_epochs(mocker, hyperopt, testdatadir, caplog) -> None:
trials = create_trials(mocker, hyperopt, testdatadir) epochs = create_results(mocker, hyperopt, testdatadir)
mock_dump = mocker.patch('freqtrade.optimize.hyperopt.dump', return_value=None) mock_dump = mocker.patch('freqtrade.optimize.hyperopt.dump', return_value=None)
trials_file = testdatadir / 'optimize' / 'ut_trials.pickle' results_file = testdatadir / 'optimize' / 'ut_results.pickle'
hyperopt.trials = trials caplog.set_level(logging.DEBUG)
hyperopt.save_trials(final=True)
assert log_has(f"1 epoch saved to '{trials_file}'.", caplog) hyperopt.epochs = epochs
hyperopt._save_results()
assert log_has(f"1 epoch saved to '{results_file}'.", caplog)
mock_dump.assert_called_once() mock_dump.assert_called_once()
hyperopt.trials = trials + trials hyperopt.epochs = epochs + epochs
hyperopt.save_trials(final=True) hyperopt._save_results()
assert log_has(f"2 epochs saved to '{trials_file}'.", caplog) assert log_has(f"2 epochs saved to '{results_file}'.", caplog)
def test_read_trials_returns_trials_file(mocker, hyperopt, testdatadir, caplog) -> None: def test_read_results_returns_epochs(mocker, hyperopt, testdatadir, caplog) -> None:
trials = create_trials(mocker, hyperopt, testdatadir) epochs = create_results(mocker, hyperopt, testdatadir)
mock_load = mocker.patch('freqtrade.optimize.hyperopt.load', return_value=trials) mock_load = mocker.patch('freqtrade.optimize.hyperopt.load', return_value=epochs)
trials_file = testdatadir / 'optimize' / 'ut_trials.pickle' results_file = testdatadir / 'optimize' / 'ut_results.pickle'
hyperopt_trial = hyperopt._read_trials(trials_file) hyperopt_epochs = hyperopt._read_results(results_file)
assert log_has(f"Reading Trials from '{trials_file}'", caplog) assert log_has(f"Reading epochs from '{results_file}'", caplog)
assert hyperopt_trial == trials assert hyperopt_epochs == epochs
mock_load.assert_called_once() mock_load.assert_called_once()