add continual learning to catboost and friends
This commit is contained in:
@@ -3,6 +3,7 @@ import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from catboost import CatBoostRegressor, Pool
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
|
||||
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
|
||||
|
||||
@@ -17,7 +18,7 @@ class CatboostRegressor(BaseRegressionModel):
|
||||
has its own DataHandler where data is held, saved, loaded, and managed.
|
||||
"""
|
||||
|
||||
def fit(self, data_dictionary: Dict) -> Any:
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
|
||||
"""
|
||||
User sets up the training and test data to fit their desired model here
|
||||
:param data_dictionary: the dictionary constructed by DataHandler to hold
|
||||
@@ -38,16 +39,16 @@ class CatboostRegressor(BaseRegressionModel):
|
||||
weight=data_dictionary["test_weights"],
|
||||
)
|
||||
|
||||
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||
init_model = None
|
||||
else:
|
||||
init_model = self.dd.model_dictionary[dk.pair]
|
||||
|
||||
model = CatBoostRegressor(
|
||||
allow_writing_files=False,
|
||||
**self.model_training_parameters,
|
||||
)
|
||||
|
||||
model.fit(X=train_data, eval_set=test_data)
|
||||
|
||||
# some evidence that catboost pools have memory leaks:
|
||||
# https://github.com/catboost/catboost/issues/1835
|
||||
del train_data, test_data
|
||||
gc.collect()
|
||||
model.fit(X=train_data, eval_set=test_data, init_model=init_model)
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user