move live retraining to separate thread.

This commit is contained in:
robcaulk 2022-05-19 21:15:58 +02:00
parent 1fae6c9ef7
commit c5ecf94177
2 changed files with 74 additions and 17 deletions

View File

@ -1,5 +1,8 @@
# import contextlib
import gc
import logging
# import sys
import threading
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Tuple
@ -16,6 +19,24 @@ from freqtrade.strategy.interface import IStrategy
pd.options.mode.chained_assignment = None
logger = logging.getLogger(__name__)
# FIXME: suppress stdout for background training
# class DummyFile(object):
# def write(self, x): pass
# @contextlib.contextmanager
# def nostdout():
# save_stdout = sys.stdout
# sys.stdout = DummyFile()
# yield
# sys.stdout = save_stdout
def threaded(fn):
    """
    Decorator that runs the wrapped callable on a new thread.

    Every call to the decorated function creates and starts a fresh
    ``threading.Thread`` with the supplied arguments and returns
    immediately, without waiting for ``fn`` to finish.

    :param fn: callable to execute on the background thread
    :returns: wrapper that starts the thread and returns the started
        ``threading.Thread`` so callers may ``join()`` it if they wish.
        The return value of ``fn`` itself is discarded.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(fn)  # keep fn's __name__/__doc__ for logs and introspection
    def wrapper(*args, **kwargs):
        thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
        thread.start()
        # Hand the thread back instead of dropping it; callers that ignore
        # the return value behave exactly as before.
        return thread

    return wrapper
class IFreqaiModel(ABC):
"""
@ -39,6 +60,8 @@ class IFreqaiModel(ABC):
self.current_time = None
self.model = None
self.predictions = None
self.training_on_separate_thread = False
self.retrain = False
def start(self, dataframe: DataFrame, metadata: dict, strategy: IStrategy) -> DataFrame:
"""
@ -122,25 +145,26 @@ class IFreqaiModel(ABC):
training_timerange=self.freqai_info[
'live_trained_timerange'])
(retrain,
new_trained_timerange) = self.dh.check_if_new_training_required(self.freqai_info[
if not self.training_on_separate_thread:
# this will also prevent other pairs from trying to train simultaneously.
(self.retrain,
new_trained_timerange) = self.dh.check_if_new_training_required(self.freqai_info[
'live_trained_timerange'],
metadata)
metadata)
else:
logger.info("FreqAI training a new model on background thread.")
self.retrain = False
if retrain or not file_exists:
self.dh.download_new_data_for_retraining(new_trained_timerange, metadata)
corr_dataframes, base_dataframes = self.dh.load_pairs_histories(new_trained_timerange,
metadata)
unfiltered_dataframe = self.dh.use_strategy_to_populate_indicators(strategy,
corr_dataframes,
base_dataframes,
metadata)
self.model = self.train(unfiltered_dataframe, metadata)
self.dh.save_data(self.model)
if self.retrain or not file_exists:
self.training_on_separate_thread = True # acts like a lock
self.retrain_model_on_separate_thread(new_trained_timerange, metadata, strategy)
self.model = self.dh.load_data()
strategy_provided_features = self.dh.find_features(dataframe)
if strategy_provided_features != self.dh.training_features_list:
self.train_model_in_series(new_trained_timerange, metadata, strategy)
preds, do_preds = self.predict(dataframe, metadata)
self.dh.append_predictions(preds, do_preds, len(dataframe))
@ -206,3 +230,38 @@ class IFreqaiModel(ABC):
else:
logger.info("Could not find model at %s", self.dh.model_path / self.dh.model_filename)
return file_exists
@threaded
def retrain_model_on_separate_thread(self, new_trained_timerange: str, metadata: dict,
                                     strategy: IStrategy):
    """
    Download new data, train a model on it and save the result.
    Because of the `threaded` decorator, calling this method starts the
    work on a new thread and returns without waiting for it.

    `self.training_on_separate_thread` is used by the caller as a lock
    (it is set to True before this method is invoked). It is always
    cleared here, even if a step fails, so that a failed training run
    cannot block every later retraining attempt.

    :param new_trained_timerange: timerange passed to the data download and
        history loading steps
    :param metadata: pair metadata forwarded to the data handler and to train()
    :param strategy: strategy used to populate indicators on the loaded data
    """
    # FIXME: wrap in nostdout() once stdout suppression for background training works
    try:
        self.dh.download_new_data_for_retraining(new_trained_timerange, metadata)
        corr_dataframes, base_dataframes = self.dh.load_pairs_histories(
            new_trained_timerange, metadata)
        unfiltered_dataframe = self.dh.use_strategy_to_populate_indicators(
            strategy, corr_dataframes, base_dataframes, metadata)
        self.model = self.train(unfiltered_dataframe, metadata)
        self.dh.save_data(self.model)
    finally:
        # Release the "lock" on both the success and the failure path. Any
        # exception still propagates to the thread's exception hook.
        self.training_on_separate_thread = False
        self.retrain = False
def train_model_in_series(self, new_trained_timerange: str, metadata: dict,
                          strategy: IStrategy):
    """
    Train a model synchronously on the calling thread and save it.

    Steps, in order: download data for the timerange, load the pair
    histories, populate indicators through the strategy, train, then save.

    :param new_trained_timerange: timerange passed to the data download and
        history loading steps
    :param metadata: pair metadata forwarded to the data handler and to train()
    :param strategy: strategy used to populate indicators on the loaded data
    """
    # Refresh the stored data before loading the histories.
    self.dh.download_new_data_for_retraining(new_trained_timerange, metadata)

    histories = self.dh.load_pairs_histories(new_trained_timerange, metadata)
    corr_frames, base_frames = histories

    populated_frame = self.dh.use_strategy_to_populate_indicators(
        strategy,
        corr_frames,
        base_frames,
        metadata,
    )

    # Keep the trained model on the instance, then save that same object.
    self.model = self.train(populated_frame, metadata)
    self.dh.save_data(self.model)

View File

@ -144,8 +144,6 @@ class FreqaiExampleStrategy(IStrategy):
self.freqai_info = self.config["freqai"]
self.pair = metadata['pair']
print("Populating indicators...")
# the following loops are necessary for building the features
# indicated by the user in the configuration file.
for tf in self.freqai_info["timeframes"]: