diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index 85a84160f..5ef9f534d 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -226,6 +226,7 @@ class FreqaiDataDrawer: historical candles, and also stores historical predictions despite retrainings (so stored predictions are true predictions, not just inferencing on trained data) """ + # dynamic df returned to strategy and plotted in frequi mrv_df = self.model_return_values[pair] = pd.DataFrame() @@ -246,6 +247,8 @@ class FreqaiDataDrawer: else: for label in dk.label_list: mrv_df[label] = pred_df[label] + if mrv_df[label].dtype == object: + continue mrv_df[f"{label}_mean"] = dk.data["labels_mean"][label] mrv_df[f"{label}_std"] = dk.data["labels_std"][label] @@ -295,6 +298,8 @@ class FreqaiDataDrawer: for label in dk.label_list: df[label].iloc[-1] = predictions[label].iloc[-1] + if df[label].dtype == object: + continue df[f"{label}_mean"].iloc[-1] = dk.data["labels_mean"][label] df[f"{label}_std"].iloc[-1] = dk.data["labels_std"][label] diff --git a/freqtrade/freqai/data_kitchen.py b/freqtrade/freqai/data_kitchen.py index 6fb75374d..eb3955f65 100644 --- a/freqtrade/freqai/data_kitchen.py +++ b/freqtrade/freqai/data_kitchen.py @@ -294,7 +294,7 @@ class FreqaiDataKitchen: self.data[item + "_min"] = train_min[item] for item in data_dictionary["train_labels"].keys(): - if data_dictionary["train_labels"][item].dtype == str: + if data_dictionary["train_labels"][item].dtype == object: continue train_labels_max = data_dictionary["train_labels"][item].max() train_labels_min = data_dictionary["train_labels"][item].min() @@ -1010,6 +1010,8 @@ class FreqaiDataKitchen: self.data["labels_mean"], self.data["labels_std"] = {}, {} for label in self.label_list: + if self.data_dictionary["train_labels"][label].dtype == object: + continue f = spy.stats.norm.fit(self.data_dictionary["train_labels"][label]) self.data["labels_mean"][label], self.data["labels_std"][label] = f[0], f[1] diff --git a/freqtrade/freqai/freqai_interface.py b/freqtrade/freqai/freqai_interface.py index 8a2aaeddb..12e0abe97 100644 --- a/freqtrade/freqai/freqai_interface.py +++ b/freqtrade/freqai/freqai_interface.py @@ -123,7 +123,7 @@ class IFreqaiModel(ABC): dataframe = dk.remove_features_from_df(dk.return_dataframe) del dk - return self.return_values(dataframe) + return dataframe @threaded def start_scanning(self, strategy: IStrategy) -> None: @@ -609,17 +609,6 @@ class IFreqaiModel(ABC): data (NaNs) or felt uncertain about data (i.e. SVM and/or DI index) """ - @abstractmethod - def return_values(self, dataframe: DataFrame) -> DataFrame: - """ - User defines the dataframe to be returned to strategy here. - :param dataframe: DataFrame = the full dataframe for the current prediction (live) - or --timerange (backtesting) - :return: dataframe: DataFrame = dataframe filled with user defined data - """ - - return - def analyze_trade_database(self, dk: FreqaiDataKitchen, pair: str) -> None: """ User analyzes the trade database here and returns summary stats which will be passed back diff --git a/freqtrade/freqai/prediction_models/BaseRegressionModel.py b/freqtrade/freqai/prediction_models/BaseRegressionModel.py index 112e48183..a3bd82a8f 100644 --- a/freqtrade/freqai/prediction_models/BaseRegressionModel.py +++ b/freqtrade/freqai/prediction_models/BaseRegressionModel.py @@ -19,15 +19,6 @@ class BaseRegressionModel(IFreqaiModel): such as prediction_models/CatboostPredictionModel.py for guidance. """ - def return_values(self, dataframe: DataFrame) -> DataFrame: - """ - User uses this function to add any additional return values to the dataframe. - e.g. - dataframe['volatility'] = dk.volatility_values - """ - - return dataframe - def train( self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen ) -> Any: diff --git a/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py b/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py index d94378494..afb439cbf 100644 --- a/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py +++ b/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py @@ -16,15 +16,6 @@ class BaseTensorFlowModel(IFreqaiModel): User *must* inherit from this class and set fit() and predict(). """ - def return_values(self, dataframe: DataFrame) -> DataFrame: - """ - User uses this function to add any additional return values to the dataframe. - e.g. - dataframe['volatility'] = dk.volatility_values - """ - - return dataframe - def train( self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen ) -> Any: diff --git a/freqtrade/freqai/prediction_models/CatboostClassifier.py b/freqtrade/freqai/prediction_models/CatboostClassifier.py new file mode 100644 index 000000000..d003744fb --- /dev/null +++ b/freqtrade/freqai/prediction_models/CatboostClassifier.py @@ -0,0 +1,44 @@ +import logging +from typing import Any, Dict + +from catboost import CatBoostClassifier, Pool + +from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel + + +logger = logging.getLogger(__name__) + + +class CatboostClassifier(BaseRegressionModel): + """ + User created prediction model. The class needs to override three necessary + functions, predict(), train(), fit(). The class inherits ModelHandler which + has its own DataHandler where data is held, saved, loaded, and managed. + """ + + def fit(self, data_dictionary: Dict) -> Any: + """ + User sets up the training and test data to fit their desired model here + :params: + :data_dictionary: the dictionary constructed by DataHandler to hold + all the training and test data/labels. + """ + + train_data = Pool( + data=data_dictionary["train_features"], + label=data_dictionary["train_labels"], + weight=data_dictionary["train_weights"], + ) + + cbr = CatBoostClassifier( + allow_writing_files=False, + gpu_ram_part=0.5, + verbose=100, + early_stopping_rounds=400, + loss_function='MultiClass', + **self.model_training_parameters, + ) + + cbr.fit(train_data) + + return cbr diff --git a/freqtrade/freqai/prediction_models/CatboostPredictionModel.py b/freqtrade/freqai/prediction_models/CatboostRegressor.py similarity index 97% rename from freqtrade/freqai/prediction_models/CatboostPredictionModel.py rename to freqtrade/freqai/prediction_models/CatboostRegressor.py index 9731e0c01..d93569c91 100644 --- a/freqtrade/freqai/prediction_models/CatboostPredictionModel.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressor.py @@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio logger = logging.getLogger(__name__) -class CatboostPredictionModel(BaseRegressionModel): +class CatboostRegressor(BaseRegressionModel): """ User created prediction model. The class needs to override three necessary functions, predict(), train(), fit(). The class inherits ModelHandler which diff --git a/freqtrade/freqai/prediction_models/CatboostPredictionMultiModel.py b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py similarity index 96% rename from freqtrade/freqai/prediction_models/CatboostPredictionMultiModel.py rename to freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py index 35a93e808..9894decd1 100644 --- a/freqtrade/freqai/prediction_models/CatboostPredictionMultiModel.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py @@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio logger = logging.getLogger(__name__) -class CatboostPredictionMultiModel(BaseRegressionModel): +class CatboostRegressorMultiTarget(BaseRegressionModel): """ User created prediction model. The class needs to override three necessary functions, predict(), train(), fit(). The class inherits ModelHandler which diff --git a/freqtrade/freqai/prediction_models/LightGBMPredictionModel.py b/freqtrade/freqai/prediction_models/LightGBMRegressor.py similarity index 96% rename from freqtrade/freqai/prediction_models/LightGBMPredictionModel.py rename to freqtrade/freqai/prediction_models/LightGBMRegressor.py index c94bc5698..f72792611 100644 --- a/freqtrade/freqai/prediction_models/LightGBMPredictionModel.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressor.py @@ -9,7 +9,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio logger = logging.getLogger(__name__) -class LightGBMPredictionModel(BaseRegressionModel): +class LightGBMRegressor(BaseRegressionModel): """ User created prediction model. The class needs to override three necessary functions, predict(), train(), fit(). The class inherits ModelHandler which diff --git a/freqtrade/freqai/prediction_models/LightGBMPredictionMultiModel.py b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py similarity index 96% rename from freqtrade/freqai/prediction_models/LightGBMPredictionMultiModel.py rename to freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py index 4c51c9008..ecd405369 100644 --- a/freqtrade/freqai/prediction_models/LightGBMPredictionMultiModel.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py @@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio logger = logging.getLogger(__name__) -class LightGBMPredictionMultiModel(BaseRegressionModel): +class LightGBMRegressorMultiTarget(BaseRegressionModel): """ User created prediction model. The class needs to override three necessary functions, predict(), train(), fit(). The class inherits ModelHandler which diff --git a/tests/freqai/conftest.py b/tests/freqai/conftest.py index ecd8b2d57..90e99951d 100644 --- a/tests/freqai/conftest.py +++ b/tests/freqai/conftest.py @@ -21,7 +21,7 @@ def freqai_conf(default_conf, tmpdir): "strategy": "freqai_test_strat", "user_data_dir": Path(tmpdir), "strategy-path": "freqtrade/tests/strategy/strats", - "freqaimodel": "LightGBMPredictionModel", + "freqaimodel": "LightGBMRegressor", "freqaimodel_path": "freqai/prediction_models", "timerange": "20180110-20180115", "freqai": { diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index c4302e756..68fc14f71 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -43,7 +43,7 @@ def test_train_model_in_series_LightGBM(mocker, freqai_conf): def test_train_model_in_series_LightGBMMultiModel(mocker, freqai_conf): freqai_conf.update({"timerange": "20180110-20180130"}) freqai_conf.update({"strategy": "freqai_test_multimodel_strat"}) - freqai_conf.update({"freqaimodel": "LightGBMPredictionMultiModel"}) + freqai_conf.update({"freqaimodel": "LightGBMRegressorMultiTarget"}) strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) strategy.dp = DataProvider(freqai_conf, exchange) @@ -73,8 +73,9 @@ def test_train_model_in_series_LightGBMMultiModel(mocker, freqai_conf): @pytest.mark.skipif("arm" in platform.uname()[-1], reason="no ARM for Catboost ...") def test_train_model_in_series_Catboost(mocker, freqai_conf): freqai_conf.update({"timerange": "20180110-20180130"}) - freqai_conf.update({"freqaimodel": "CatboostPredictionModel"}) - del freqai_conf['freqai']['model_training_parameters']['verbosity'] + freqai_conf.update({"freqaimodel": "CatboostRegressor"}) + freqai_conf.get('freqai', {}).update( + {'model_training_parameters': {"n_estimators": 100, "verbose": 0}}) strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) strategy.dp = DataProvider(freqai_conf, exchange)