diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index b6f3d8ebc..101df88ec 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -661,11 +661,20 @@ class IFreqaiModel(ABC): self.train_time = 0 return + def get_init_model(self, pair: str) -> Any: + if pair not in self.dd.model_dictionary or not self.continual_learning: + init_model = None + else: + init_model = self.dd.model_dictionary[pair] + + return init_model + # Following methods which are overridden by user made prediction models. # See freqai/prediction_models/CatboostPredictionModel.py for an example. @abstractmethod - def train(self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen) -> Any: + def train(self, unfiltered_dataframe: DataFrame, pair: str, + dk: FreqaiDataKitchen, **kwargs) -> Any: """ Filter the training data and train a model to it. Train makes heavy use of the datahandler for storing, saving, loading, and analyzing the data. @@ -675,7 +684,7 @@ class IFreqaiModel(ABC): """ @abstractmethod - def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen) -> Any: + def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any: """ Most regressors use the same function names and arguments e.g. user can drop in LGBMRegressor in place of CatBoostRegressor and all data @@ -688,7 +697,7 @@ class IFreqaiModel(ABC): @abstractmethod def predict( - self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True + self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True, **kwargs ) -> Tuple[DataFrame, NDArray[np.int_]]: """ Filter the prediction features data and predict with it. diff --git a/freqtrade/freqai/prediction_models/CatboostClassifier.py b/freqtrade/freqai/prediction_models/CatboostClassifier.py index 13395879a..cd7afd392 100644 --- a/freqtrade/freqai/prediction_models/CatboostClassifier.py +++ b/freqtrade/freqai/prediction_models/CatboostClassifier.py @@ -2,6 +2,7 @@ import logging from typing import Any, Dict from catboost import CatBoostClassifier, Pool + from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel @@ -16,7 +17,7 @@ class CatboostClassifier(BaseClassifierModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here :params: @@ -36,10 +37,7 @@ class CatboostClassifier(BaseClassifierModel): **self.model_training_parameters, ) - if dk.pair not in self.dd.model_dictionary or not self.continual_learning: - init_model = None - else: - init_model = self.dd.model_dictionary[dk.pair] + init_model = self.get_init_model(dk.pair) cbr.fit(train_data, init_model=init_model) diff --git a/freqtrade/freqai/prediction_models/CatboostRegressor.py b/freqtrade/freqai/prediction_models/CatboostRegressor.py index 0b8bc162b..1ce31b628 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressor.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressor.py @@ -1,10 +1,9 @@ -import gc import logging from typing import Any, Dict from catboost import CatBoostRegressor, Pool -from freqtrade.freqai.data_kitchen import FreqaiDataKitchen +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel @@ -18,7 +17,7 @@ class CatboostRegressor(BaseRegressionModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here :param data_dictionary: the dictionary constructed by DataHandler to hold @@ -39,10 +38,7 @@ class CatboostRegressor(BaseRegressionModel): weight=data_dictionary["test_weights"], ) - if dk.pair not in self.dd.model_dictionary or not self.continual_learning: - init_model = None - else: - init_model = self.dd.model_dictionary[dk.pair] + init_model = self.get_init_model(dk.pair) model = CatBoostRegressor( allow_writing_files=False, diff --git a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py index 9ed61488c..bc52bfdd9 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py @@ -3,6 +3,7 @@ from typing import Any, Dict from catboost import CatBoostRegressor # , Pool from sklearn.multioutput import MultiOutputRegressor + from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel @@ -17,7 +18,7 @@ class CatboostRegressorMultiTarget(BaseRegressionModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here :param data_dictionary: the dictionary constructed by DataHandler to hold diff --git a/freqtrade/freqai/prediction_models/LightGBMClassifier.py b/freqtrade/freqai/prediction_models/LightGBMClassifier.py index 0023a9f69..69867eae3 100644 --- a/freqtrade/freqai/prediction_models/LightGBMClassifier.py +++ b/freqtrade/freqai/prediction_models/LightGBMClassifier.py @@ -3,8 +3,9 @@ from typing import Any, Dict from lightgbm import LGBMClassifier -from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel from freqtrade.freqai.data_kitchen import FreqaiDataKitchen +from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel + logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ class LightGBMClassifier(BaseClassifierModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here :params: @@ -35,10 +36,7 @@ class LightGBMClassifier(BaseClassifierModel): y = data_dictionary["train_labels"].to_numpy()[:, 0] train_weights = data_dictionary["train_weights"] - if dk.pair not in self.dd.model_dictionary or not self.continual_learning: - init_model = None - else: - init_model = self.dd.model_dictionary[dk.pair] + init_model = self.get_init_model(dk.pair) model = LGBMClassifier(**self.model_training_parameters) diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressor.py b/freqtrade/freqai/prediction_models/LightGBMRegressor.py index 81f0e6d22..99e9ff887 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressor.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressor.py @@ -3,8 +3,9 @@ from typing import Any, Dict from lightgbm import LGBMRegressor -from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel from freqtrade.freqai.data_kitchen import FreqaiDataKitchen +from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel + logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ class LightGBMRegressor(BaseRegressionModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ Most regressors use the same function names and arguments e.g. user can drop in LGBMRegressor in place of CatBoostRegressor and all data @@ -35,10 +36,7 @@ class LightGBMRegressor(BaseRegressionModel): y = data_dictionary["train_labels"] train_weights = data_dictionary["train_weights"] - if dk.pair not in self.dd.model_dictionary or not self.continual_learning: - init_model = None - else: - init_model = self.dd.model_dictionary[dk.pair] + init_model = self.get_init_model(dk.pair) model = LGBMRegressor(**self.model_training_parameters) diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py index 2b25493e0..c34680dbe 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py @@ -4,8 +4,9 @@ from typing import Any, Dict from lightgbm import LGBMRegressor from sklearn.multioutput import MultiOutputRegressor -from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel from freqtrade.freqai.data_kitchen import FreqaiDataKitchen +from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel + logger = logging.getLogger(__name__) @@ -17,7 +18,7 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel): has its own DataHandler where data is held, saved, loaded, and managed. """ - def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any: + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: """ User sets up the training and test data to fit their desired model here :param data_dictionary: the dictionary constructed by DataHandler to hold