From 58de20af0f1226dd5b3985d16d72bd979ec9fcf5 Mon Sep 17 00:00:00 2001 From: robcaulk Date: Sat, 13 Aug 2022 20:07:31 +0200 Subject: [PATCH] make BaseClassifierModel. Add predict_proba to lightgbm --- freqtrade/freqai/data_drawer.py | 2 +- .../prediction_models/BaseClassifierModel.py | 99 +++++++++++++++++++ .../prediction_models/CatboostClassifier.py | 45 ++------- .../prediction_models/LightGBMClassifier.py | 4 +- 4 files changed, 108 insertions(+), 42 deletions(-) create mode 100644 freqtrade/freqai/prediction_models/BaseClassifierModel.py diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index 02ef156bd..4ba55a0ec 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -12,7 +12,7 @@ import pandas as pd import rapidjson from joblib import dump, load from joblib.externals import cloudpickle -from numpy.typing import ArrayLike, NDArray +from numpy.typing import NDArray from pandas import DataFrame from freqtrade.configuration import TimeRange diff --git a/freqtrade/freqai/prediction_models/BaseClassifierModel.py b/freqtrade/freqai/prediction_models/BaseClassifierModel.py new file mode 100644 index 000000000..2edbf3b51 --- /dev/null +++ b/freqtrade/freqai/prediction_models/BaseClassifierModel.py @@ -0,0 +1,99 @@ +import logging +from typing import Any, Tuple + +import numpy as np +import numpy.typing as npt +import pandas as pd +from pandas import DataFrame + +from freqtrade.freqai.data_kitchen import FreqaiDataKitchen +from freqtrade.freqai.freqai_interface import IFreqaiModel + + +logger = logging.getLogger(__name__) + + +class BaseClassifierModel(IFreqaiModel): + """ + Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.). + User *must* inherit from this class and set fit() and predict(). See example scripts + such as prediction_models/CatboostPredictionModel.py for guidance. + """ + + def train( + self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen + ) -> Any: + """ + Filter the training data and train a model to it. Train makes heavy use of the datakitchen + for storing, saving, loading, and analyzing the data. + :param unfiltered_dataframe: Full dataframe for the current training period + :param metadata: pair metadata from strategy. + :return: + :model: Trained model which can be used to inference (self.predict) + """ + + logger.info("-------------------- Starting training " f"{pair} --------------------") + + # filter the features requested by user in the configuration file and elegantly handle NaNs + features_filtered, labels_filtered = dk.filter_features( + unfiltered_dataframe, + dk.training_features_list, + dk.label_list, + training_filter=True, + ) + + start_date = unfiltered_dataframe["date"].iloc[0].strftime("%Y-%m-%d") + end_date = unfiltered_dataframe["date"].iloc[-1].strftime("%Y-%m-%d") + logger.info(f"-------------------- Training on data from {start_date} to " + f"{end_date}--------------------") + # split data into train/test data. + data_dictionary = dk.make_train_test_datasets(features_filtered, labels_filtered) + if not self.freqai_info.get('fit_live_predictions', 0) or not self.live: + dk.fit_labels() + # normalize all data based on train_dataset only + data_dictionary = dk.normalize_data(data_dictionary) + + # optional additional data cleaning/analysis + self.data_cleaning_train(dk) + + logger.info( + f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features" + ) + logger.info(f'Training model on {len(data_dictionary["train_features"])} data points') + + model = self.fit(data_dictionary) + + logger.info(f"--------------------done training {pair}--------------------") + + return model + + def predict( + self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False + ) -> Tuple[DataFrame, npt.NDArray[np.int_]]: + """ + Filter the prediction features data and predict with it. + :param: unfiltered_dataframe: Full dataframe for the current backtest period. + :return: + :pred_df: dataframe containing the predictions + :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove + data (NaNs) or felt uncertain about data (PCA and DI index) + """ + + dk.find_features(unfiltered_dataframe) + filtered_dataframe, _ = dk.filter_features( + unfiltered_dataframe, dk.training_features_list, training_filter=False + ) + filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe) + dk.data_dictionary["prediction_features"] = filtered_dataframe + + self.data_cleaning_predict(dk, filtered_dataframe) + + predictions = self.model.predict(dk.data_dictionary["prediction_features"]) + pred_df = DataFrame(predictions, columns=dk.label_list) + + predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) + pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_) + + pred_df = pd.concat([pred_df, pred_df_prob], axis=1) + + return (pred_df, dk.do_predict) diff --git a/freqtrade/freqai/prediction_models/CatboostClassifier.py b/freqtrade/freqai/prediction_models/CatboostClassifier.py index 7a4b06557..b88b28b25 100644 --- a/freqtrade/freqai/prediction_models/CatboostClassifier.py +++ b/freqtrade/freqai/prediction_models/CatboostClassifier.py @@ -1,17 +1,15 @@ import logging -from typing import Any, Dict, Tuple -import pandas as pd -from pandas import DataFrame +from typing import Any, Dict + from catboost import CatBoostClassifier, Pool -import numpy.typing as npt -import numpy as np -from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel -from freqtrade.freqai.data_kitchen import FreqaiDataKitchen + +from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel + logger = logging.getLogger(__name__) -class CatboostClassifier(BaseRegressionModel): +class CatboostClassifier(BaseClassifierModel): """ User created prediction model. The class needs to override three necessary functions, predict(), train(), fit(). The class inherits ModelHandler which @@ -41,34 +39,3 @@ class CatboostClassifier(BaseRegressionModel): cbr.fit(train_data) return cbr - - def predict( - self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False - ) -> Tuple[DataFrame, npt.NDArray[np.int_]]: - """ - Filter the prediction features data and predict with it. - :param: unfiltered_dataframe: Full dataframe for the current backtest period. - :return: - :pred_df: dataframe containing the predictions - :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove - data (NaNs) or felt uncertain about data (PCA and DI index) - """ - - dk.find_features(unfiltered_dataframe) - filtered_dataframe, _ = dk.filter_features( - unfiltered_dataframe, dk.training_features_list, training_filter=False - ) - filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe) - dk.data_dictionary["prediction_features"] = filtered_dataframe - - self.data_cleaning_predict(dk, filtered_dataframe) - - predictions = self.model.predict(dk.data_dictionary["prediction_features"]) - pred_df = DataFrame(predictions, columns=dk.label_list) - - predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) - pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_) - - pred_df = pd.concat([pred_df, pred_df_prob], axis=1) - - return (pred_df, dk.do_predict) diff --git a/freqtrade/freqai/prediction_models/LightGBMClassifier.py b/freqtrade/freqai/prediction_models/LightGBMClassifier.py index 782dbce35..bafb16a39 100644 --- a/freqtrade/freqai/prediction_models/LightGBMClassifier.py +++ b/freqtrade/freqai/prediction_models/LightGBMClassifier.py @@ -3,13 +3,13 @@ from typing import Any, Dict from lightgbm import LGBMClassifier -from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel +from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel logger = logging.getLogger(__name__) -class LightGBMClassifier(BaseRegressionModel): +class LightGBMClassifier(BaseClassifierModel): """ User created prediction model. The class needs to override three necessary functions, predict(), train(), fit(). The class inherits ModelHandler which