add classifier, improve model naming scheme
This commit is contained in:
parent
ce8fbbf743
commit
07763d0d4f
@ -226,6 +226,7 @@ class FreqaiDataDrawer:
|
|||||||
historical candles, and also stores historical predictions despite retrainings (so stored
|
historical candles, and also stores historical predictions despite retrainings (so stored
|
||||||
predictions are true predictions, not just inferencing on trained data)
|
predictions are true predictions, not just inferencing on trained data)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# dynamic df returned to strategy and plotted in frequi
|
# dynamic df returned to strategy and plotted in frequi
|
||||||
mrv_df = self.model_return_values[pair] = pd.DataFrame()
|
mrv_df = self.model_return_values[pair] = pd.DataFrame()
|
||||||
|
|
||||||
@ -246,6 +247,8 @@ class FreqaiDataDrawer:
|
|||||||
else:
|
else:
|
||||||
for label in dk.label_list:
|
for label in dk.label_list:
|
||||||
mrv_df[label] = pred_df[label]
|
mrv_df[label] = pred_df[label]
|
||||||
|
if mrv_df[label].dtype == object:
|
||||||
|
continue
|
||||||
mrv_df[f"{label}_mean"] = dk.data["labels_mean"][label]
|
mrv_df[f"{label}_mean"] = dk.data["labels_mean"][label]
|
||||||
mrv_df[f"{label}_std"] = dk.data["labels_std"][label]
|
mrv_df[f"{label}_std"] = dk.data["labels_std"][label]
|
||||||
|
|
||||||
@ -295,6 +298,8 @@ class FreqaiDataDrawer:
|
|||||||
|
|
||||||
for label in dk.label_list:
|
for label in dk.label_list:
|
||||||
df[label].iloc[-1] = predictions[label].iloc[-1]
|
df[label].iloc[-1] = predictions[label].iloc[-1]
|
||||||
|
if df[label].dtype == object:
|
||||||
|
continue
|
||||||
df[f"{label}_mean"].iloc[-1] = dk.data["labels_mean"][label]
|
df[f"{label}_mean"].iloc[-1] = dk.data["labels_mean"][label]
|
||||||
df[f"{label}_std"].iloc[-1] = dk.data["labels_std"][label]
|
df[f"{label}_std"].iloc[-1] = dk.data["labels_std"][label]
|
||||||
|
|
||||||
|
@ -294,7 +294,7 @@ class FreqaiDataKitchen:
|
|||||||
self.data[item + "_min"] = train_min[item]
|
self.data[item + "_min"] = train_min[item]
|
||||||
|
|
||||||
for item in data_dictionary["train_labels"].keys():
|
for item in data_dictionary["train_labels"].keys():
|
||||||
if data_dictionary["train_labels"][item].dtype == str:
|
if data_dictionary["train_labels"][item].dtype == object:
|
||||||
continue
|
continue
|
||||||
train_labels_max = data_dictionary["train_labels"][item].max()
|
train_labels_max = data_dictionary["train_labels"][item].max()
|
||||||
train_labels_min = data_dictionary["train_labels"][item].min()
|
train_labels_min = data_dictionary["train_labels"][item].min()
|
||||||
@ -1010,6 +1010,8 @@ class FreqaiDataKitchen:
|
|||||||
|
|
||||||
self.data["labels_mean"], self.data["labels_std"] = {}, {}
|
self.data["labels_mean"], self.data["labels_std"] = {}, {}
|
||||||
for label in self.label_list:
|
for label in self.label_list:
|
||||||
|
if self.data_dictionary["train_labels"][label].dtype == object:
|
||||||
|
continue
|
||||||
f = spy.stats.norm.fit(self.data_dictionary["train_labels"][label])
|
f = spy.stats.norm.fit(self.data_dictionary["train_labels"][label])
|
||||||
self.data["labels_mean"][label], self.data["labels_std"][label] = f[0], f[1]
|
self.data["labels_mean"][label], self.data["labels_std"][label] = f[0], f[1]
|
||||||
|
|
||||||
|
@ -123,7 +123,7 @@ class IFreqaiModel(ABC):
|
|||||||
|
|
||||||
dataframe = dk.remove_features_from_df(dk.return_dataframe)
|
dataframe = dk.remove_features_from_df(dk.return_dataframe)
|
||||||
del dk
|
del dk
|
||||||
return self.return_values(dataframe)
|
return dataframe
|
||||||
|
|
||||||
@threaded
|
@threaded
|
||||||
def start_scanning(self, strategy: IStrategy) -> None:
|
def start_scanning(self, strategy: IStrategy) -> None:
|
||||||
@ -609,17 +609,6 @@ class IFreqaiModel(ABC):
|
|||||||
data (NaNs) or felt uncertain about data (i.e. SVM and/or DI index)
|
data (NaNs) or felt uncertain about data (i.e. SVM and/or DI index)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def return_values(self, dataframe: DataFrame) -> DataFrame:
|
|
||||||
"""
|
|
||||||
User defines the dataframe to be returned to strategy here.
|
|
||||||
:param dataframe: DataFrame = the full dataframe for the current prediction (live)
|
|
||||||
or --timerange (backtesting)
|
|
||||||
:return: dataframe: DataFrame = dataframe filled with user defined data
|
|
||||||
"""
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
def analyze_trade_database(self, dk: FreqaiDataKitchen, pair: str) -> None:
|
def analyze_trade_database(self, dk: FreqaiDataKitchen, pair: str) -> None:
|
||||||
"""
|
"""
|
||||||
User analyzes the trade database here and returns summary stats which will be passed back
|
User analyzes the trade database here and returns summary stats which will be passed back
|
||||||
|
@ -19,15 +19,6 @@ class BaseRegressionModel(IFreqaiModel):
|
|||||||
such as prediction_models/CatboostPredictionModel.py for guidance.
|
such as prediction_models/CatboostPredictionModel.py for guidance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def return_values(self, dataframe: DataFrame) -> DataFrame:
|
|
||||||
"""
|
|
||||||
User uses this function to add any additional return values to the dataframe.
|
|
||||||
e.g.
|
|
||||||
dataframe['volatility'] = dk.volatility_values
|
|
||||||
"""
|
|
||||||
|
|
||||||
return dataframe
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
@ -16,15 +16,6 @@ class BaseTensorFlowModel(IFreqaiModel):
|
|||||||
User *must* inherit from this class and set fit() and predict().
|
User *must* inherit from this class and set fit() and predict().
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def return_values(self, dataframe: DataFrame) -> DataFrame:
|
|
||||||
"""
|
|
||||||
User uses this function to add any additional return values to the dataframe.
|
|
||||||
e.g.
|
|
||||||
dataframe['volatility'] = dk.volatility_values
|
|
||||||
"""
|
|
||||||
|
|
||||||
return dataframe
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
44
freqtrade/freqai/prediction_models/CatboostClassifier.py
Normal file
44
freqtrade/freqai/prediction_models/CatboostClassifier.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from catboost import CatBoostClassifier, Pool
|
||||||
|
|
||||||
|
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CatboostClassifier(BaseRegressionModel):
|
||||||
|
"""
|
||||||
|
User created prediction model. The class needs to override three necessary
|
||||||
|
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||||
|
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fit(self, data_dictionary: Dict) -> Any:
|
||||||
|
"""
|
||||||
|
User sets up the training and test data to fit their desired model here
|
||||||
|
:params:
|
||||||
|
:data_dictionary: the dictionary constructed by DataHandler to hold
|
||||||
|
all the training and test data/labels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
train_data = Pool(
|
||||||
|
data=data_dictionary["train_features"],
|
||||||
|
label=data_dictionary["train_labels"],
|
||||||
|
weight=data_dictionary["train_weights"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cbr = CatBoostClassifier(
|
||||||
|
allow_writing_files=False,
|
||||||
|
gpu_ram_part=0.5,
|
||||||
|
verbose=100,
|
||||||
|
early_stopping_rounds=400,
|
||||||
|
loss_function='MultiClass',
|
||||||
|
**self.model_training_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
cbr.fit(train_data)
|
||||||
|
|
||||||
|
return cbr
|
@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CatboostPredictionModel(BaseRegressionModel):
|
class CatboostRegressor(BaseRegressionModel):
|
||||||
"""
|
"""
|
||||||
User created prediction model. The class needs to override three necessary
|
User created prediction model. The class needs to override three necessary
|
||||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CatboostPredictionMultiModel(BaseRegressionModel):
|
class CatboostRegressorMultiTarget(BaseRegressionModel):
|
||||||
"""
|
"""
|
||||||
User created prediction model. The class needs to override three necessary
|
User created prediction model. The class needs to override three necessary
|
||||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
@ -9,7 +9,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LightGBMPredictionModel(BaseRegressionModel):
|
class LightGBMRegressor(BaseRegressionModel):
|
||||||
"""
|
"""
|
||||||
User created prediction model. The class needs to override three necessary
|
User created prediction model. The class needs to override three necessary
|
||||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LightGBMPredictionMultiModel(BaseRegressionModel):
|
class LightGBMRegressorMultiTarget(BaseRegressionModel):
|
||||||
"""
|
"""
|
||||||
User created prediction model. The class needs to override three necessary
|
User created prediction model. The class needs to override three necessary
|
||||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
@ -21,7 +21,7 @@ def freqai_conf(default_conf, tmpdir):
|
|||||||
"strategy": "freqai_test_strat",
|
"strategy": "freqai_test_strat",
|
||||||
"user_data_dir": Path(tmpdir),
|
"user_data_dir": Path(tmpdir),
|
||||||
"strategy-path": "freqtrade/tests/strategy/strats",
|
"strategy-path": "freqtrade/tests/strategy/strats",
|
||||||
"freqaimodel": "LightGBMPredictionModel",
|
"freqaimodel": "LightGBMRegressor",
|
||||||
"freqaimodel_path": "freqai/prediction_models",
|
"freqaimodel_path": "freqai/prediction_models",
|
||||||
"timerange": "20180110-20180115",
|
"timerange": "20180110-20180115",
|
||||||
"freqai": {
|
"freqai": {
|
||||||
|
@ -43,7 +43,7 @@ def test_train_model_in_series_LightGBM(mocker, freqai_conf):
|
|||||||
def test_train_model_in_series_LightGBMMultiModel(mocker, freqai_conf):
|
def test_train_model_in_series_LightGBMMultiModel(mocker, freqai_conf):
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||||
freqai_conf.update({"strategy": "freqai_test_multimodel_strat"})
|
freqai_conf.update({"strategy": "freqai_test_multimodel_strat"})
|
||||||
freqai_conf.update({"freqaimodel": "LightGBMPredictionMultiModel"})
|
freqai_conf.update({"freqaimodel": "LightGBMRegressorMultiTarget"})
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||||
@ -73,8 +73,9 @@ def test_train_model_in_series_LightGBMMultiModel(mocker, freqai_conf):
|
|||||||
@pytest.mark.skipif("arm" in platform.uname()[-1], reason="no ARM for Catboost ...")
|
@pytest.mark.skipif("arm" in platform.uname()[-1], reason="no ARM for Catboost ...")
|
||||||
def test_train_model_in_series_Catboost(mocker, freqai_conf):
|
def test_train_model_in_series_Catboost(mocker, freqai_conf):
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||||
freqai_conf.update({"freqaimodel": "CatboostPredictionModel"})
|
freqai_conf.update({"freqaimodel": "CatboostRegressor"})
|
||||||
del freqai_conf['freqai']['model_training_parameters']['verbosity']
|
freqai_conf.get('freqai', {}).update(
|
||||||
|
{'model_training_parameters': {"n_estimators": 100, "verbose": 0}})
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||||
|
Loading…
Reference in New Issue
Block a user