add continual learning to catboost and friends

This commit is contained in:
robcaulk
2022-09-06 20:30:37 +02:00
parent dc4a4bdf09
commit 97077ba18a
11 changed files with 48 additions and 24 deletions

View File

@@ -86,6 +86,7 @@ class IFreqaiModel(ABC):
self.begin_time: float = 0
self.begin_time_train: float = 0
self.base_tf_seconds = timeframe_to_seconds(self.config['timeframe'])
self.continual_learning = self.freqai_info.get('continual_learning', False)
self._threads: List[threading.Thread] = []
self._stop_event = threading.Event()
@@ -674,7 +675,7 @@ class IFreqaiModel(ABC):
"""
@abstractmethod
def fit(self, data_dictionary: Dict[str, Any]) -> Any:
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen) -> Any:
"""
Most regressors use the same function names and arguments e.g. user
can drop in LGBMRegressor in place of CatBoostRegressor and all data

View File

@@ -61,7 +61,7 @@ class BaseClassifierModel(IFreqaiModel):
)
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
model = self.fit(data_dictionary)
model = self.fit(data_dictionary, dk)
logger.info(f"--------------------done training {pair}--------------------")

View File

@@ -60,7 +60,7 @@ class BaseRegressionModel(IFreqaiModel):
)
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
model = self.fit(data_dictionary)
model = self.fit(data_dictionary, dk)
logger.info(f"--------------------done training {pair}--------------------")

View File

@@ -57,7 +57,7 @@ class BaseTensorFlowModel(IFreqaiModel):
)
logger.info(f'Training model on {len(data_dictionary["train_features"])} data points')
model = self.fit(data_dictionary)
model = self.fit(data_dictionary, dk)
logger.info(f"--------------------done training {pair}--------------------")

View File

@@ -2,7 +2,7 @@ import logging
from typing import Any, Dict
from catboost import CatBoostClassifier, Pool
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
@@ -16,7 +16,7 @@ class CatboostClassifier(BaseClassifierModel):
has its own DataHandler where data is held, saved, loaded, and managed.
"""
def fit(self, data_dictionary: Dict) -> Any:
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
"""
User sets up the training and test data to fit their desired model here
:params:
@@ -36,6 +36,11 @@ class CatboostClassifier(BaseClassifierModel):
**self.model_training_parameters,
)
cbr.fit(train_data)
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
cbr.fit(train_data, init_model=init_model)
return cbr

View File

@@ -3,6 +3,7 @@ import logging
from typing import Any, Dict
from catboost import CatBoostRegressor, Pool
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
@@ -17,7 +18,7 @@ class CatboostRegressor(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed.
"""
def fit(self, data_dictionary: Dict) -> Any:
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
"""
User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold
@@ -38,16 +39,16 @@ class CatboostRegressor(BaseRegressionModel):
weight=data_dictionary["test_weights"],
)
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
model = CatBoostRegressor(
allow_writing_files=False,
**self.model_training_parameters,
)
model.fit(X=train_data, eval_set=test_data)
# some evidence that catboost pools have memory leaks:
# https://github.com/catboost/catboost/issues/1835
del train_data, test_data
gc.collect()
model.fit(X=train_data, eval_set=test_data, init_model=init_model)
return model

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict
from catboost import CatBoostRegressor # , Pool
from sklearn.multioutput import MultiOutputRegressor
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
@@ -17,7 +17,7 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed.
"""
def fit(self, data_dictionary: Dict) -> Any:
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
"""
User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold
@@ -34,6 +34,9 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
eval_set = (data_dictionary["test_features"], data_dictionary["test_labels"])
sample_weight = data_dictionary["train_weights"]
if self.continual_learning:
logger.warning('Continual learning not supported for MultiTarget models')
model = MultiOutputRegressor(estimator=cbr)
model.fit(X=X, y=y, sample_weight=sample_weight) # , eval_set=eval_set)

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict
from lightgbm import LGBMClassifier
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
logger = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ class LightGBMClassifier(BaseClassifierModel):
has its own DataHandler where data is held, saved, loaded, and managed.
"""
def fit(self, data_dictionary: Dict) -> Any:
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
"""
User sets up the training and test data to fit their desired model here
:params:
@@ -35,9 +35,14 @@ class LightGBMClassifier(BaseClassifierModel):
y = data_dictionary["train_labels"].to_numpy()[:, 0]
train_weights = data_dictionary["train_weights"]
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
model = LGBMClassifier(**self.model_training_parameters)
model.fit(X=X, y=y, eval_set=eval_set, sample_weight=train_weights,
eval_sample_weight=[test_weights])
eval_sample_weight=[test_weights], init_model=init_model)
return model

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict
from lightgbm import LGBMRegressor
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
logger = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ class LightGBMRegressor(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed.
"""
def fit(self, data_dictionary: Dict) -> Any:
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
"""
Most regressors use the same function names and arguments e.g. user
can drop in LGBMRegressor in place of CatBoostRegressor and all data
@@ -35,9 +35,14 @@ class LightGBMRegressor(BaseRegressionModel):
y = data_dictionary["train_labels"]
train_weights = data_dictionary["train_weights"]
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
model = LGBMRegressor(**self.model_training_parameters)
model.fit(X=X, y=y, eval_set=eval_set, sample_weight=train_weights,
eval_sample_weight=[eval_weights])
eval_sample_weight=[eval_weights], init_model=init_model)
return model

View File

@@ -5,7 +5,7 @@ from lightgbm import LGBMRegressor
from sklearn.multioutput import MultiOutputRegressor
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed.
"""
def fit(self, data_dictionary: Dict) -> Any:
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
"""
User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold
@@ -31,6 +31,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
eval_set = (data_dictionary["test_features"], data_dictionary["test_labels"])
sample_weight = data_dictionary["train_weights"]
if self.continual_learning:
logger.warning('Continual learning not supported for MultiTarget models')
model = MultiOutputRegressor(estimator=lgb)
model.fit(X=X, y=y, sample_weight=sample_weight) # , eval_set=eval_set)
train_score = model.score(X, y)