diff --git a/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py b/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py new file mode 100644 index 000000000..098ff24dd --- /dev/null +++ b/freqtrade/freqai/prediction_models/BaseTensorFlowModel.py @@ -0,0 +1,78 @@ +import logging +from typing import Tuple + +from pandas import DataFrame + +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen +from freqtrade.freqai.freqai_interface import IFreqaiModel + + +logger = logging.getLogger(__name__) + + +class BaseTensorFlowModel(IFreqaiModel): + """ + Base class for TensorFlow type models. + User *must* inherit from this class and set fit() and predict(). + """ + + def return_values(self, dataframe: DataFrame, dk: FreqaiDataKitchen) -> DataFrame: + """ + User uses this function to add any additional return values to the dataframe. + e.g. + dataframe['volatility'] = dk.volatility_values + """ + + return dataframe + + def train( + self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen + ) -> Tuple[DataFrame, DataFrame]: + """ + Filter the training data and train a model to it. Train makes heavy use of the datakitchen + for storing, saving, loading, and analyzing the data. + :params: + :unfiltered_dataframe: Full dataframe for the current training period + :metadata: pair metadata from strategy. + :returns: + :model: Trained model which can be used to inference (self.predict) + """ + + logger.info("--------------------Starting training " f"{pair} --------------------") + + # filter the features requested by user in the configuration file and elegantly handle NaNs + features_filtered, labels_filtered = dk.filter_features( + unfiltered_dataframe, + dk.training_features_list, + dk.label_list, + training_filter=True, + ) + + # split data into train/test data. + data_dictionary = dk.make_train_test_datasets(features_filtered, labels_filtered) + if not self.freqai_info.get('fit_live_predictions', 0): + dk.fit_labels() + # normalize all data based on train_dataset only + data_dictionary = dk.normalize_data(data_dictionary) + + # optional additional data cleaning/analysis + self.data_cleaning_train(dk) + + logger.info( + f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features" + ) + logger.info(f'Training model on {len(data_dictionary["train_features"])} data points') + + model = self.fit(data_dictionary) + + if pair not in self.dd.historic_predictions: + self.set_initial_historic_predictions( + data_dictionary['train_features'], model, dk, pair) + elif self.freqai_info.get('fit_live_predictions_candles', 0): + dk.fit_live_predictions() + + self.dd.save_historic_predictions_to_disk() + + logger.info(f"--------------------done training {pair}--------------------") + + return model