add continual learning to catboost and friends

This commit is contained in:
robcaulk
2022-09-06 20:30:37 +02:00
parent dc4a4bdf09
commit 97077ba18a
11 changed files with 48 additions and 24 deletions

View File

@@ -2,7 +2,7 @@ import logging
from typing import Any, Dict
from catboost import CatBoostClassifier, Pool
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
@@ -16,7 +16,7 @@ class CatboostClassifier(BaseClassifierModel):
has its own DataHandler where data is held, saved, loaded, and managed.
"""
def fit(self, data_dictionary: Dict) -> Any:
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen) -> Any:
"""
User sets up the training and test data to fit their desired model here
:params:
@@ -36,6 +36,11 @@ class CatboostClassifier(BaseClassifierModel):
**self.model_training_parameters,
)
cbr.fit(train_data)
if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
init_model = None
else:
init_model = self.dd.model_dictionary[dk.pair]
cbr.fit(train_data, init_model=init_model)
return cbr