diff --git a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py new file mode 100644 index 000000000..5283501d1 --- /dev/null +++ b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py @@ -0,0 +1,43 @@ +import logging +from typing import Any, Dict + +from sklearn.multioutput import MultiOutputRegressor +from xgboost import XGBRegressor + +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen +from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel + + +logger = logging.getLogger(__name__) + + +class XGBoostRegressorMultiTarget(BaseRegressionModel): + """ + User created prediction model. The class needs to override three necessary + functions, predict(), train(), fit(). The class inherits ModelHandler which + has its own DataHandler where data is held, saved, loaded, and managed. + """ + + def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: + """ + User sets up the training and test data to fit their desired model here + :param data_dictionary: the dictionary constructed by DataHandler to hold + all the training and test data/labels. + """ + + xgb = XGBRegressor(**self.model_training_parameters) + + X = data_dictionary["train_features"] + y = data_dictionary["train_labels"] + eval_set = (data_dictionary["test_features"], data_dictionary["test_labels"]) + sample_weight = data_dictionary["train_weights"] + + if self.continual_learning: + logger.warning('Continual learning not supported for MultiTarget models') + + model = MultiOutputRegressor(estimator=xgb) + model.fit(X=X, y=y, sample_weight=sample_weight) # , eval_set=eval_set) + train_score = model.score(X, y) + test_score = model.score(*eval_set) + logger.info(f"Train score {train_score}, Test score {test_score}") + return model diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 7783c00e7..ff0eb24a9 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -203,6 +203,37 @@ def test_train_model_in_series_XGBoostRegressor(mocker, freqai_conf): shutil.rmtree(Path(freqai.dk.full_path)) +def test_train_model_in_series_XGBoostRegressorMultiModel(mocker, freqai_conf): + freqai_conf.update({"timerange": "20180110-20180130"}) + freqai_conf.update({"freqaimodel": "XGBoostRegressorMultiTarget"}) + freqai_conf.update({"strategy": "freqai_test_multimodel_strat"}) + strategy = get_patched_freqai_strategy(mocker, freqai_conf) + exchange = get_patched_exchange(mocker, freqai_conf) + strategy.dp = DataProvider(freqai_conf, exchange) + strategy.freqai_info = freqai_conf.get("freqai", {}) + freqai = strategy.freqai + freqai.live = True + freqai.dk = FreqaiDataKitchen(freqai_conf) + timerange = TimeRange.parse_timerange("20180110-20180130") + freqai.dd.load_all_pair_histories(timerange, freqai.dk) + + freqai.dd.pair_dict = MagicMock() + + data_load_timerange = TimeRange.parse_timerange("20180110-20180130") + new_timerange = TimeRange.parse_timerange("20180120-20180130") + + freqai.train_model_in_series(new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) + + assert len(freqai.dk.label_list) == 2 + assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file() + assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file() + assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file() + assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file() + assert len(freqai.dk.data['training_features_list']) == 26 + + shutil.rmtree(Path(freqai.dk.full_path)) + + def test_start_backtesting(mocker, freqai_conf): freqai_conf.update({"timerange": "20180120-20180130"}) freqai_conf.get("freqai", {}).update({"save_backtest_models": True})