add classifier, improve model naming scheme
This commit is contained in:
parent
ce8fbbf743
commit
07763d0d4f
@ -226,6 +226,7 @@ class FreqaiDataDrawer:
|
||||
historical candles, and also stores historical predictions despite retrainings (so stored
|
||||
predictions are true predictions, not just inferencing on trained data)
|
||||
"""
|
||||
|
||||
# dynamic df returned to strategy and plotted in frequi
|
||||
mrv_df = self.model_return_values[pair] = pd.DataFrame()
|
||||
|
||||
@ -246,6 +247,8 @@ class FreqaiDataDrawer:
|
||||
else:
|
||||
for label in dk.label_list:
|
||||
mrv_df[label] = pred_df[label]
|
||||
if mrv_df[label].dtype == object:
|
||||
continue
|
||||
mrv_df[f"{label}_mean"] = dk.data["labels_mean"][label]
|
||||
mrv_df[f"{label}_std"] = dk.data["labels_std"][label]
|
||||
|
||||
@ -295,6 +298,8 @@ class FreqaiDataDrawer:
|
||||
|
||||
for label in dk.label_list:
|
||||
df[label].iloc[-1] = predictions[label].iloc[-1]
|
||||
if df[label].dtype == object:
|
||||
continue
|
||||
df[f"{label}_mean"].iloc[-1] = dk.data["labels_mean"][label]
|
||||
df[f"{label}_std"].iloc[-1] = dk.data["labels_std"][label]
|
||||
|
||||
|
@ -294,7 +294,7 @@ class FreqaiDataKitchen:
|
||||
self.data[item + "_min"] = train_min[item]
|
||||
|
||||
for item in data_dictionary["train_labels"].keys():
|
||||
if data_dictionary["train_labels"][item].dtype == str:
|
||||
if data_dictionary["train_labels"][item].dtype == object:
|
||||
continue
|
||||
train_labels_max = data_dictionary["train_labels"][item].max()
|
||||
train_labels_min = data_dictionary["train_labels"][item].min()
|
||||
@ -1010,6 +1010,8 @@ class FreqaiDataKitchen:
|
||||
|
||||
self.data["labels_mean"], self.data["labels_std"] = {}, {}
|
||||
for label in self.label_list:
|
||||
if self.data_dictionary["train_labels"][label].dtype == object:
|
||||
continue
|
||||
f = spy.stats.norm.fit(self.data_dictionary["train_labels"][label])
|
||||
self.data["labels_mean"][label], self.data["labels_std"][label] = f[0], f[1]
|
||||
|
||||
|
@ -123,7 +123,7 @@ class IFreqaiModel(ABC):
|
||||
|
||||
dataframe = dk.remove_features_from_df(dk.return_dataframe)
|
||||
del dk
|
||||
return self.return_values(dataframe)
|
||||
return dataframe
|
||||
|
||||
@threaded
|
||||
def start_scanning(self, strategy: IStrategy) -> None:
|
||||
@ -609,17 +609,6 @@ class IFreqaiModel(ABC):
|
||||
data (NaNs) or felt uncertain about data (i.e. SVM and/or DI index)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def return_values(self, dataframe: DataFrame) -> DataFrame:
|
||||
"""
|
||||
User defines the dataframe to be returned to strategy here.
|
||||
:param dataframe: DataFrame = the full dataframe for the current prediction (live)
|
||||
or --timerange (backtesting)
|
||||
:return: dataframe: DataFrame = dataframe filled with user defined data
|
||||
"""
|
||||
|
||||
return
|
||||
|
||||
def analyze_trade_database(self, dk: FreqaiDataKitchen, pair: str) -> None:
|
||||
"""
|
||||
User analyzes the trade database here and returns summary stats which will be passed back
|
||||
|
@ -19,15 +19,6 @@ class BaseRegressionModel(IFreqaiModel):
|
||||
such as prediction_models/CatboostPredictionModel.py for guidance.
|
||||
"""
|
||||
|
||||
def return_values(self, dataframe: DataFrame) -> DataFrame:
|
||||
"""
|
||||
User uses this function to add any additional return values to the dataframe.
|
||||
e.g.
|
||||
dataframe['volatility'] = dk.volatility_values
|
||||
"""
|
||||
|
||||
return dataframe
|
||||
|
||||
def train(
|
||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
||||
) -> Any:
|
||||
|
@ -16,15 +16,6 @@ class BaseTensorFlowModel(IFreqaiModel):
|
||||
User *must* inherit from this class and set fit() and predict().
|
||||
"""
|
||||
|
||||
def return_values(self, dataframe: DataFrame) -> DataFrame:
|
||||
"""
|
||||
User uses this function to add any additional return values to the dataframe.
|
||||
e.g.
|
||||
dataframe['volatility'] = dk.volatility_values
|
||||
"""
|
||||
|
||||
return dataframe
|
||||
|
||||
def train(
|
||||
self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen
|
||||
) -> Any:
|
||||
|
44
freqtrade/freqai/prediction_models/CatboostClassifier.py
Normal file
44
freqtrade/freqai/prediction_models/CatboostClassifier.py
Normal file
@ -0,0 +1,44 @@
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from catboost import CatBoostClassifier, Pool
|
||||
|
||||
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CatboostClassifier(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:params:
|
||||
:data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
"""
|
||||
|
||||
train_data = Pool(
|
||||
data=data_dictionary["train_features"],
|
||||
label=data_dictionary["train_labels"],
|
||||
weight=data_dictionary["train_weights"],
|
||||
)
|
||||
|
||||
cbr = CatBoostClassifier(
|
||||
allow_writing_files=False,
|
||||
gpu_ram_part=0.5,
|
||||
verbose=100,
|
||||
early_stopping_rounds=400,
|
||||
loss_function='MultiClass',
|
||||
**self.model_training_parameters,
|
||||
)
|
||||
|
||||
cbr.fit(train_data)
|
||||
|
||||
return cbr
|
@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CatboostPredictionModel(BaseRegressionModel):
|
||||
class CatboostRegressor(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CatboostPredictionMultiModel(BaseRegressionModel):
|
||||
class CatboostRegressorMultiTarget(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
@ -9,7 +9,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LightGBMPredictionModel(BaseRegressionModel):
|
||||
class LightGBMRegressor(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
@ -10,7 +10,7 @@ from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressio
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LightGBMPredictionMultiModel(BaseRegressionModel):
|
||||
class LightGBMRegressorMultiTarget(BaseRegressionModel):
|
||||
"""
|
||||
User created prediction model. The class needs to override three necessary
|
||||
functions, predict(), train(), fit(). The class inherits ModelHandler which
|
@ -21,7 +21,7 @@ def freqai_conf(default_conf, tmpdir):
|
||||
"strategy": "freqai_test_strat",
|
||||
"user_data_dir": Path(tmpdir),
|
||||
"strategy-path": "freqtrade/tests/strategy/strats",
|
||||
"freqaimodel": "LightGBMPredictionModel",
|
||||
"freqaimodel": "LightGBMRegressor",
|
||||
"freqaimodel_path": "freqai/prediction_models",
|
||||
"timerange": "20180110-20180115",
|
||||
"freqai": {
|
||||
|
@ -43,7 +43,7 @@ def test_train_model_in_series_LightGBM(mocker, freqai_conf):
|
||||
def test_train_model_in_series_LightGBMMultiModel(mocker, freqai_conf):
|
||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||
freqai_conf.update({"strategy": "freqai_test_multimodel_strat"})
|
||||
freqai_conf.update({"freqaimodel": "LightGBMPredictionMultiModel"})
|
||||
freqai_conf.update({"freqaimodel": "LightGBMRegressorMultiTarget"})
|
||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||
@ -73,8 +73,9 @@ def test_train_model_in_series_LightGBMMultiModel(mocker, freqai_conf):
|
||||
@pytest.mark.skipif("arm" in platform.uname()[-1], reason="no ARM for Catboost ...")
|
||||
def test_train_model_in_series_Catboost(mocker, freqai_conf):
|
||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||
freqai_conf.update({"freqaimodel": "CatboostPredictionModel"})
|
||||
del freqai_conf['freqai']['model_training_parameters']['verbosity']
|
||||
freqai_conf.update({"freqaimodel": "CatboostRegressor"})
|
||||
freqai_conf.get('freqai', {}).update(
|
||||
{'model_training_parameters': {"n_estimators": 100, "verbose": 0}})
|
||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||
|
Loading…
Reference in New Issue
Block a user