diff --git a/freqtrade/freqai/prediction_models/BaseClassifierModel.py b/freqtrade/freqai/prediction_models/BaseClassifierModel.py
index e51e26e0f..291bacc82 100644
--- a/freqtrade/freqai/prediction_models/BaseClassifierModel.py
+++ b/freqtrade/freqai/prediction_models/BaseClassifierModel.py
@@ -21,7 +21,7 @@ class BaseClassifierModel(IFreqaiModel):
     """
 
     def train(
-        self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
+        self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
     ) -> Any:
         """
         Filter the training data and train a model to it. Train makes heavy use of the datakitchen
@@ -68,7 +68,7 @@ class BaseClassifierModel(IFreqaiModel):
         return model
 
     def predict(
-        self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False
+        self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False, **kwargs
     ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
         """
         Filter the prediction features data and predict with it.
@@ -79,9 +79,9 @@ class BaseClassifierModel(IFreqaiModel):
         data (NaNs) or felt uncertain about data (PCA and DI index)
         """
 
-        dk.find_features(unfiltered_dataframe)
+        dk.find_features(dataframe)
         filtered_dataframe, _ = dk.filter_features(
-            unfiltered_dataframe, dk.training_features_list, training_filter=False
+            dataframe, dk.training_features_list, training_filter=False
         )
         filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
         dk.data_dictionary["prediction_features"] = filtered_dataframe
diff --git a/freqtrade/freqai/prediction_models/BaseRegressionModel.py b/freqtrade/freqai/prediction_models/BaseRegressionModel.py
index 45f0c2937..da6fba571 100644
--- a/freqtrade/freqai/prediction_models/BaseRegressionModel.py
+++ b/freqtrade/freqai/prediction_models/BaseRegressionModel.py
@@ -20,7 +20,7 @@ class BaseRegressionModel(IFreqaiModel):
     """
 
     def train(
-        self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
+        self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
     ) -> Any:
         """
         Filter the training data and train a model to it. Train makes heavy use of the datakitchen
@@ -67,7 +67,7 @@ class BaseRegressionModel(IFreqaiModel):
         return model
 
     def predict(
-        self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False
+        self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False, **kwargs
     ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
         """
         Filter the prediction features data and predict with it.
@@ -78,9 +78,9 @@ class BaseRegressionModel(IFreqaiModel):
         data (NaNs) or felt uncertain about data (PCA and DI index)
         """
 
-        dk.find_features(unfiltered_dataframe)
+        dk.find_features(dataframe)
         filtered_dataframe, _ = dk.filter_features(
-            unfiltered_dataframe, dk.training_features_list, training_filter=False
+            dataframe, dk.training_features_list, training_filter=False
         )
         filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
         dk.data_dictionary["prediction_features"] = filtered_dataframe
diff --git a/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py b/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py
index 66e6ec1fc..6fb49239b 100644
--- a/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py
+++ b/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py
@@ -17,7 +17,7 @@ class BaseTensorFlowModel(IFreqaiModel):
     """
 
     def train(
-        self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
+        self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
     ) -> Any:
         """
         Filter the training data and train a model to it. Train makes heavy use of the datakitchen
diff --git a/freqtrade/freqai/prediction_models/XGBoostRegressor.py b/freqtrade/freqai/prediction_models/XGBoostRegressor.py
new file mode 100644
index 000000000..a8f250d16
--- /dev/null
+++ b/freqtrade/freqai/prediction_models/XGBoostRegressor.py
@@ -0,0 +1,52 @@
+import logging
+from typing import Any, Dict
+
+import xgboost as xgb
+
+from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
+from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
+
+
+logger = logging.getLogger(__name__)
+
+
+class XGBoostRegressor(BaseRegressionModel):
+    """
+    User created prediction model. The class needs to override three necessary
+    functions, predict(), train(), fit(). The class inherits ModelHandler which
+    has its own DataHandler where data is held, saved, loaded, and managed.
+    """
+
+    def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
+        """
+        User sets up the training and test data to fit their desired model here
+        :param data_dictionary: the dictionary constructed by DataHandler to hold
+                                all the training and test data/labels.
+        :param dk: data kitchen of the pair being trained; only dk.pair is read here.
+        :return: the fitted xgb.XGBRegressor.
+        """
+
+        # set_config() changes XGBoost's global configuration, not just this fit.
+        # The bare xgb.config_context(verbosity=2) call that followed was dropped:
+        # it only builds a context manager and does nothing outside a `with` block.
+        xgb.set_config(verbosity=2)
+
+        X = data_dictionary["train_features"]
+        y = data_dictionary["train_labels"]
+
+        if self.freqai_info.get("data_split_parameters", {}).get("test_size", 0.1) == 0:
+            # test_size of 0: no evaluation set is passed to fit().
+            eval_set = None
+        else:
+            eval_set = [(data_dictionary["test_features"], data_dictionary["test_labels"])]
+
+        sample_weight = data_dictionary["train_weights"]
+
+        # NOTE(review): presumably an earlier model to continue from, or None - confirm.
+        xgb_model = self.get_init_model(dk.pair)
+
+        model = xgb.XGBRegressor(**self.model_training_parameters)
+
+        model.fit(X=X, y=y, sample_weight=sample_weight, eval_set=eval_set, xgb_model=xgb_model)
+
+        return model
diff --git a/requirements-freqai.txt b/requirements-freqai.txt
index 26e4617af..e8d950382 100644
--- a/requirements-freqai.txt
+++ b/requirements-freqai.txt
@@ -6,3 +6,4 @@ scikit-learn==1.1.2
 joblib==1.1.0
 catboost==1.0.6; platform_machine != 'aarch64'
 lightgbm==3.3.2
+xgboost==1.6.2
diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py
index 5441b3c24..7783c00e7 100644
--- a/tests/freqai/test_freqai_interface.py
+++ b/tests/freqai/test_freqai_interface.py
@@ -172,6 +172,37 @@ def test_train_model_in_series_LightGBMClassifier(mocker, freqai_conf):
     shutil.rmtree(Path(freqai.dk.full_path))
 
 
+def test_train_model_in_series_XGBoostRegressor(mocker, freqai_conf):
+    freqai_conf.update({"timerange": "20180110-20180130"})
+    freqai_conf.update({"freqaimodel": "XGBoostRegressor"})
+    freqai_conf.update({"strategy": "freqai_test_strat"})
+
+    strategy = get_patched_freqai_strategy(mocker, freqai_conf)
+    exchange = get_patched_exchange(mocker, freqai_conf)
+    strategy.dp = DataProvider(freqai_conf, exchange)
+    strategy.freqai_info = freqai_conf.get("freqai", {})
+    freqai = strategy.freqai
+    freqai.live = True
+    freqai.dk = FreqaiDataKitchen(freqai_conf)
+    timerange = TimeRange.parse_timerange("20180110-20180130")
+    freqai.dd.load_all_pair_histories(timerange, freqai.dk)
+
+    freqai.dd.pair_dict = MagicMock()
+
+    data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
+    new_timerange = TimeRange.parse_timerange("20180120-20180130")
+
+    freqai.train_model_in_series(new_timerange, "ADA/BTC",
+                                 strategy, freqai.dk, data_load_timerange)
+
+    assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file()
+    assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
+    assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
+    assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
+
+    shutil.rmtree(Path(freqai.dk.full_path))
+
+
 def test_start_backtesting(mocker, freqai_conf):
     freqai_conf.update({"timerange": "20180120-20180130"})
     freqai_conf.get("freqai", {}).update({"save_backtest_models": True})