Add XGBoostRegressor for freqAI, fix mypy errors
This commit is contained in:
parent
4c9ac6b7c0
commit
1b6410d7d1
@ -21,7 +21,7 @@ class BaseClassifierModel(IFreqaiModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
||||||
@ -68,7 +68,7 @@ class BaseClassifierModel(IFreqaiModel):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False
|
self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
@ -79,9 +79,9 @@ class BaseClassifierModel(IFreqaiModel):
|
|||||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dk.find_features(unfiltered_dataframe)
|
dk.find_features(dataframe)
|
||||||
filtered_dataframe, _ = dk.filter_features(
|
filtered_dataframe, _ = dk.filter_features(
|
||||||
unfiltered_dataframe, dk.training_features_list, training_filter=False
|
dataframe, dk.training_features_list, training_filter=False
|
||||||
)
|
)
|
||||||
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
||||||
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
||||||
|
@ -20,7 +20,7 @@ class BaseRegressionModel(IFreqaiModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
||||||
@ -67,7 +67,7 @@ class BaseRegressionModel(IFreqaiModel):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False
|
self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False, **kwargs
|
||||||
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
|
||||||
"""
|
"""
|
||||||
Filter the prediction features data and predict with it.
|
Filter the prediction features data and predict with it.
|
||||||
@ -78,9 +78,9 @@ class BaseRegressionModel(IFreqaiModel):
|
|||||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dk.find_features(unfiltered_dataframe)
|
dk.find_features(dataframe)
|
||||||
filtered_dataframe, _ = dk.filter_features(
|
filtered_dataframe, _ = dk.filter_features(
|
||||||
unfiltered_dataframe, dk.training_features_list, training_filter=False
|
dataframe, dk.training_features_list, training_filter=False
|
||||||
)
|
)
|
||||||
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
|
||||||
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
dk.data_dictionary["prediction_features"] = filtered_dataframe
|
||||||
|
@ -17,7 +17,7 @@ class BaseTensorFlowModel(IFreqaiModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
Filter the training data and train a model to it. Train makes heavy use of the datakitchen
|
||||||
|
46
freqtrade/freqai/prediction_models/XGBoostRegressor.py
Normal file
46
freqtrade/freqai/prediction_models/XGBoostRegressor.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
|
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class XGBoostRegressor(BaseRegressionModel):
|
||||||
|
"""
|
||||||
|
User created prediction model. The class needs to override three necessary
|
||||||
|
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||||
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
User sets up the training and test data to fit their desired model here
|
||||||
|
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
|
all the training and test data/labels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
xgb.set_config(verbosity=2)
|
||||||
|
xgb.config_context(verbosity=2)
|
||||||
|
|
||||||
|
X = data_dictionary["train_features"]
|
||||||
|
y = data_dictionary["train_labels"]
|
||||||
|
|
||||||
|
if self.freqai_info.get("data_split_parameters", {}).get("test_size", 0.1) == 0:
|
||||||
|
eval_set = None
|
||||||
|
else:
|
||||||
|
eval_set = [(data_dictionary["test_features"], data_dictionary["test_labels"])]
|
||||||
|
|
||||||
|
sample_weight = data_dictionary["train_weights"]
|
||||||
|
|
||||||
|
xgb_model = self.get_init_model(dk.pair)
|
||||||
|
|
||||||
|
model = xgb.XGBRegressor(**self.model_training_parameters)
|
||||||
|
|
||||||
|
model.fit(X=X, y=y, sample_weight=sample_weight, eval_set=eval_set, xgb_model=xgb_model)
|
||||||
|
|
||||||
|
return model
|
@ -6,3 +6,4 @@ scikit-learn==1.1.2
|
|||||||
joblib==1.1.0
|
joblib==1.1.0
|
||||||
catboost==1.0.6; platform_machine != 'aarch64'
|
catboost==1.0.6; platform_machine != 'aarch64'
|
||||||
lightgbm==3.3.2
|
lightgbm==3.3.2
|
||||||
|
xgboost==1.6.2
|
||||||
|
@ -172,6 +172,37 @@ def test_train_model_in_series_LightGBMClassifier(mocker, freqai_conf):
|
|||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
shutil.rmtree(Path(freqai.dk.full_path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_model_in_series_XGBoostRegressor(mocker, freqai_conf):
|
||||||
|
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||||
|
freqai_conf.update({"freqaimodel": "XGBoostRegressor"})
|
||||||
|
freqai_conf.update({"strategy": "freqai_test_strat"})
|
||||||
|
|
||||||
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
|
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||||
|
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||||
|
strategy.freqai_info = freqai_conf.get("freqai", {})
|
||||||
|
freqai = strategy.freqai
|
||||||
|
freqai.live = True
|
||||||
|
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
||||||
|
timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||||
|
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
||||||
|
|
||||||
|
freqai.dd.pair_dict = MagicMock()
|
||||||
|
|
||||||
|
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||||
|
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||||
|
|
||||||
|
freqai.train_model_in_series(new_timerange, "ADA/BTC",
|
||||||
|
strategy, freqai.dk, data_load_timerange)
|
||||||
|
|
||||||
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file()
|
||||||
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
|
||||||
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
|
||||||
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
|
||||||
|
|
||||||
|
shutil.rmtree(Path(freqai.dk.full_path))
|
||||||
|
|
||||||
|
|
||||||
def test_start_backtesting(mocker, freqai_conf):
|
def test_start_backtesting(mocker, freqai_conf):
|
||||||
freqai_conf.update({"timerange": "20180120-20180130"})
|
freqai_conf.update({"timerange": "20180120-20180130"})
|
||||||
freqai_conf.get("freqai", {}).update({"save_backtest_models": True})
|
freqai_conf.get("freqai", {}).update({"save_backtest_models": True})
|
||||||
|
Loading…
Reference in New Issue
Block a user