add continual learning to catboost and friends

This commit is contained in:
robcaulk
2022-09-06 20:30:37 +02:00
parent dc4a4bdf09
commit 97077ba18a
11 changed files with 48 additions and 24 deletions

View File

@@ -5,7 +5,7 @@ from lightgbm import LGBMRegressor
from sklearn.multioutput import MultiOutputRegressor
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
has its own DataHandler where data is held, saved, loaded, and managed.
"""
def fit(self, data_dictionary: Dict) -> Any:
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
"""
User sets up the training and test data to fit their desired model here
:param data_dictionary: the dictionary constructed by DataHandler to hold
@@ -31,6 +31,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
eval_set = (data_dictionary["test_features"], data_dictionary["test_labels"])
sample_weight = data_dictionary["train_weights"]
if self.continual_learning:
logger.warning('Continual learning not supported for MultiTarget models')
model = MultiOutputRegressor(estimator=lgb)
model.fit(X=X, y=y, sample_weight=sample_weight) # , eval_set=eval_set)
train_score = model.score(X, y)