Merge pull request #7367 from freqtrade/add-continual-learning
add continual learning to catboost and friends
This commit is contained in:
@@ -88,6 +88,7 @@ class IFreqaiModel(ABC):
|
||||
self.begin_time: float = 0
|
||||
self.begin_time_train: float = 0
|
||||
self.base_tf_seconds = timeframe_to_seconds(self.config['timeframe'])
|
||||
self.continual_learning = self.freqai_info.get('continual_learning', False)
|
||||
|
||||
self._threads: List[threading.Thread] = []
|
||||
self._stop_event = threading.Event()
|
||||
@@ -676,21 +677,30 @@ class IFreqaiModel(ABC):
|
||||
self.train_time = 0
|
||||
return
|
||||
|
||||
def get_init_model(self, pair: str) -> Any:
|
||||
if pair not in self.dd.model_dictionary or not self.continual_learning:
|
||||
init_model = None
|
||||
else:
|
||||
init_model = self.dd.model_dictionary[pair]
|
||||
|
||||
return init_model
|
||||
|
||||
# Following methods which are overridden by user made prediction models.
|
||||
# See freqai/prediction_models/CatboostPredictionModel.py for an example.
|
||||
|
||||
@abstractmethod
|
||||
def train(self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen) -> Any:
|
||||
def train(self, unfiltered_df: DataFrame, pair: str,
|
||||
dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
Filter the training data and train a model to it. Train makes heavy use of the datahandler
|
||||
for storing, saving, loading, and analyzing the data.
|
||||
:param unfiltered_dataframe: Full dataframe for the current training period
|
||||
:param unfiltered_df: Full dataframe for the current training period
|
||||
:param metadata: pair metadata from strategy.
|
||||
:return: Trained model which can be used to inference (self.predict)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, data_dictionary: Dict[str, Any]) -> Any:
|
||||
def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
Most regressors use the same function names and arguments e.g. user
|
||||
can drop in LGBMRegressor in place of CatBoostRegressor and all data
|
||||
@@ -703,11 +713,11 @@ class IFreqaiModel(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def predict(
|
||||
self, dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = True
|
||||
self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
|
||||
) -> Tuple[DataFrame, NDArray[np.int_]]:
|
||||
"""
|
||||
Filter the prediction features data and predict with it.
|
||||
:param unfiltered_dataframe: Full dataframe for the current backtest period.
|
||||
:param unfiltered_df: Full dataframe for the current backtest period.
|
||||
:param dk: FreqaiDataKitchen = Data management/analysis tool associated to present pair only
|
||||
:param first: boolean = whether this is the first prediction or not.
|
||||
:return:
|
||||
|
||||
Reference in New Issue
Block a user