From 97077ba18acd9ea0ad67ba45e917aca6bdcb3b0d Mon Sep 17 00:00:00 2001 From: robcaulk Date: Tue, 6 Sep 2022 20:30:37 +0200 Subject: [PATCH] add continual learning to catboost and friends --- docs/freqai.md | 1 + freqtrade/freqai/freqai_interface.py | 3 ++- .../prediction_models/BaseClassifierModel.py | 2 +- .../prediction_models/BaseRegressionModel.py | 2 +- .../prediction_models/BaseTensorFlowModel.py | 2 +- .../prediction_models/CatboostClassifier.py | 11 ++++++++--- .../freqai/prediction_models/CatboostRegressor.py | 15 ++++++++------- .../CatboostRegressorMultiTarget.py | 7 +++++-- .../prediction_models/LightGBMClassifier.py | 11 ++++++++--- .../freqai/prediction_models/LightGBMRegressor.py | 11 ++++++++--- .../LightGBMRegressorMultiTarget.py | 7 +++++-- 11 files changed, 48 insertions(+), 24 deletions(-) diff --git a/docs/freqai.md b/docs/freqai.md index c0844bf32..e790bbb81 100644 --- a/docs/freqai.md +++ b/docs/freqai.md @@ -98,6 +98,7 @@ Mandatory parameters are marked as **Required**, which means that they are requi | `expiration_hours` | Avoid making predictions if a model is more than `expiration_hours` old.
Defaults set to 0, which means models never expire.
**Datatype:** Positive integer. | `fit_live_predictions_candles` | Number of historical candles to use for computing target (label) statistics from prediction data, instead of from the training data set.
**Datatype:** Positive integer. | `follow_mode` | If true, this instance of FreqAI will look for models associated with `identifier` and load those for inferencing. A `follower` will **not** train new models.
**Datatype:** Boolean. Default: `False`. +| `continual_learning` | If true, FreqAI will start training new models from the final state of the most recently trained model.
**Datatype:** Boolean. Default: `False`. | | **Feature parameters** | `feature_parameters` | A dictionary containing the parameters used to engineer the feature set. Details and examples are shown [here](#feature-engineering).
**Datatype:** Dictionary. | `include_timeframes` | A list of timeframes that all indicators in `populate_any_indicators` will be created for. The list is added as features to the base asset feature set.
**Datatype:** List of timeframes (strings). diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index a9c21fb65..b6f3d8ebc 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -86,6 +86,7 @@ class IFreqaiModel(ABC): self.begin_time: float = 0 self.begin_time_train: float = 0 self.base_tf_seconds = timeframe_to_seconds(self.config['timeframe']) + self.continual_learning = self.freqai_info.get('continual_learning', False) self._threads: List[threading.Thread] = [] self._stop_event = threading.Event() @@ -674,7 +675,7 @@ class IFreqaiModel(ABC): """ @abstractmethod - def fit(self, data_dictionary: Dict[str, Any]) -> Any: + def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen) -> Any: """ Most regressors use the same function names and arguments e.g. user can drop in LGBMRegressor in place of CatBoostRegressor and all data diff --git a/freqtrade/freqai/prediction_models/BaseClassifierModel.py b/freqtrade/freqai/prediction_models/BaseClassifierModel.py index 2edbf3b51..e51e26e0f 100644 --- a/freqtrade/freqai/prediction_models/BaseClassifierModel.py +++ b/freqtrade/freqai/prediction_models/BaseClassifierModel.py @@ -61,7 +61,7 @@ class BaseClassifierModel(IFreqaiModel): ) logger.info(f'Training model on {len(data_dictionary["train_features"])} data points') - model = self.fit(data_dictionary) + model = self.fit(data_dictionary, dk) logger.info(f"--------------------done training {pair}--------------------") diff --git a/freqtrade/freqai/prediction_models/BaseRegressionModel.py b/freqtrade/freqai/prediction_models/BaseRegressionModel.py index 2ef175a2e..45f0c2937 100644 --- a/freqtrade/freqai/prediction_models/BaseRegressionModel.py +++ b/freqtrade/freqai/prediction_models/BaseRegressionModel.py @@ -60,7 +60,7 @@ class BaseRegressionModel(IFreqaiModel): ) logger.info(f'Training model on {len(data_dictionary["train_features"])} data points') - model = self.fit(data_dictionary) + model = self.fit(data_dictionary, dk) logger.info(f"--------------------done training {pair}--------------------") diff --git a/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py b/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py index 04eff045f..66e6ec1fc 100644 --- a/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py +++ b/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py @@ -57,7 +57,7 @@ class BaseTensorFlowModel(IFreqaiModel): ) logger.info(f'Training model on {len(data_dictionary["train_features"])} data points') - model = self.fit(data_dictionary) + model = self.fit(data_dictionary, dk) logger.info(f"--------------------done training {pair}--------------------") diff --git a/freqtrade/freqai/prediction_models/CatboostClassifier.py b/freqtrade/freqai/prediction_models/CatboostClassifier.py index b88b28b25..13395879a 100644 --- a/freqtrade/freqai/prediction_models/CatboostClassifier.py +++ b/freqtrade/freqai/prediction_models/CatboostClassifier.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict from catboost import CatBoostClassifier, Pool - +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel @@ -16,7 +16,7 @@ class CatboostClassifier(BaseClassifierModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: """ User sets up the training and test data to fit their desired model here :params: @@ -36,6 +36,11 @@ class CatboostClassifier(BaseClassifierModel): **self.model_training_parameters, ) - cbr.fit(train_data) + if dk.pair not in self.dd.model_dictionary or not self.continual_learning: + init_model = None + else: + init_model = self.dd.model_dictionary[dk.pair] + + cbr.fit(train_data, init_model=init_model) return cbr diff --git a/freqtrade/freqai/prediction_models/CatboostRegressor.py b/freqtrade/freqai/prediction_models/CatboostRegressor.py index d93569c91..0b8bc162b 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressor.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressor.py @@ -3,6 +3,7 @@ import logging from typing import Any, Dict from catboost import CatBoostRegressor, Pool +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel @@ -17,7 +18,7 @@ class CatboostRegressor(BaseRegressionModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: """ User sets up the training and test data to fit their desired model here :param data_dictionary: the dictionary constructed by DataHandler to hold @@ -38,16 +39,16 @@ class CatboostRegressor(BaseRegressionModel): weight=data_dictionary["test_weights"], ) + if dk.pair not in self.dd.model_dictionary or not self.continual_learning: + init_model = None + else: + init_model = self.dd.model_dictionary[dk.pair] + model = CatBoostRegressor( allow_writing_files=False, **self.model_training_parameters, ) - model.fit(X=train_data, eval_set=test_data) - - # some evidence that catboost pools have memory leaks: - # https://github.com/catboost/catboost/issues/1835 - del train_data, test_data - gc.collect() + model.fit(X=train_data, eval_set=test_data, init_model=init_model) return model diff --git a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py index 9894decd1..9ed61488c 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py @@ -3,7 +3,7 @@ from typing import Any, Dict from catboost import CatBoostRegressor # , Pool from sklearn.multioutput import MultiOutputRegressor - +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel @@ -17,7 +17,7 @@ class CatboostRegressorMultiTarget(BaseRegressionModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: """ User sets up the training and test data to fit their desired model here :param data_dictionary: the dictionary constructed by DataHandler to hold @@ -34,6 +34,9 @@ class CatboostRegressorMultiTarget(BaseRegressionModel): eval_set = (data_dictionary["test_features"], data_dictionary["test_labels"]) sample_weight = data_dictionary["train_weights"] + if self.continual_learning: + logger.warning('Continual learning not supported for MultiTarget models') + model = MultiOutputRegressor(estimator=cbr) model.fit(X=X, y=y, sample_weight=sample_weight) # , eval_set=eval_set) diff --git a/freqtrade/freqai/prediction_models/LightGBMClassifier.py b/freqtrade/freqai/prediction_models/LightGBMClassifier.py index 4ac2c448b..0023a9f69 100644 --- a/freqtrade/freqai/prediction_models/LightGBMClassifier.py +++ b/freqtrade/freqai/prediction_models/LightGBMClassifier.py @@ -4,7 +4,7 @@ from typing import Any, Dict from lightgbm import LGBMClassifier from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel - +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ class LightGBMClassifier(BaseClassifierModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: """ User sets up the training and test data to fit their desired model here :params: @@ -35,9 +35,14 @@ class LightGBMClassifier(BaseClassifierModel): y = data_dictionary["train_labels"].to_numpy()[:, 0] train_weights = data_dictionary["train_weights"] + if dk.pair not in self.dd.model_dictionary or not self.continual_learning: + init_model = None + else: + init_model = self.dd.model_dictionary[dk.pair] + model = LGBMClassifier(**self.model_training_parameters) model.fit(X=X, y=y, eval_set=eval_set, sample_weight=train_weights, - eval_sample_weight=[test_weights]) + eval_sample_weight=[test_weights], init_model=init_model) return model diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressor.py b/freqtrade/freqai/prediction_models/LightGBMRegressor.py index 2431fd2ad..81f0e6d22 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressor.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressor.py @@ -4,7 +4,7 @@ from typing import Any, Dict from lightgbm import LGBMRegressor from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel - +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ class LightGBMRegressor(BaseRegressionModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: """ Most regressors use the same function names and arguments e.g. user can drop in LGBMRegressor in place of CatBoostRegressor and all data @@ -35,9 +35,14 @@ class LightGBMRegressor(BaseRegressionModel): y = data_dictionary["train_labels"] train_weights = data_dictionary["train_weights"] + if dk.pair not in self.dd.model_dictionary or not self.continual_learning: + init_model = None + else: + init_model = self.dd.model_dictionary[dk.pair] + model = LGBMRegressor(**self.model_training_parameters) model.fit(X=X, y=y, eval_set=eval_set, sample_weight=train_weights, - eval_sample_weight=[eval_weights]) + eval_sample_weight=[eval_weights], init_model=init_model) return model diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py index ecd405369..2b25493e0 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py @@ -5,7 +5,7 @@ from lightgbm import LGBMRegressor from sklearn.multioutput import MultiOutputRegressor from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel - +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: """ User sets up the training and test data to fit their desired model here :param data_dictionary: the dictionary constructed by DataHandler to hold @@ -31,6 +31,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel): eval_set = (data_dictionary["test_features"], data_dictionary["test_labels"]) sample_weight = data_dictionary["train_weights"] + if self.continual_learning: + logger.warning('Continual learning not supported for MultiTarget models') + model = MultiOutputRegressor(estimator=lgb) model.fit(X=X, y=y, sample_weight=sample_weight) # , eval_set=eval_set) train_score = model.score(X, y)