make BaseClassifierModel. Add predict_proba to lightgbm
This commit is contained in:
		| @@ -12,7 +12,7 @@ import pandas as pd | ||||
| import rapidjson | ||||
| from joblib import dump, load | ||||
| from joblib.externals import cloudpickle | ||||
| from numpy.typing import ArrayLike, NDArray | ||||
| from numpy.typing import NDArray | ||||
| from pandas import DataFrame | ||||
|  | ||||
| from freqtrade.configuration import TimeRange | ||||
|   | ||||
							
								
								
									
										99
									
								
								freqtrade/freqai/prediction_models/BaseClassifierModel.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								freqtrade/freqai/prediction_models/BaseClassifierModel.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,99 @@ | ||||
| import logging | ||||
| from typing import Any, Tuple | ||||
|  | ||||
| import numpy as np | ||||
| import numpy.typing as npt | ||||
| import pandas as pd | ||||
| from pandas import DataFrame | ||||
|  | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.freqai.freqai_interface import IFreqaiModel | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class BaseClassifierModel(IFreqaiModel): | ||||
|     """ | ||||
|     Base class for regression type models (e.g. Catboost, LightGBM, XGboost etc.). | ||||
|     User *must* inherit from this class and set fit() and predict(). See example scripts | ||||
|     such as prediction_models/CatboostPredictionModel.py for guidance. | ||||
|     """ | ||||
|  | ||||
|     def train( | ||||
|         self, unfiltered_dataframe: DataFrame, pair: str, dk: FreqaiDataKitchen | ||||
|     ) -> Any: | ||||
|         """ | ||||
|         Filter the training data and train a model to it. Train makes heavy use of the datakitchen | ||||
|         for storing, saving, loading, and analyzing the data. | ||||
|         :param unfiltered_dataframe: Full dataframe for the current training period | ||||
|         :param metadata: pair metadata from strategy. | ||||
|         :return: | ||||
|         :model: Trained model which can be used to inference (self.predict) | ||||
|         """ | ||||
|  | ||||
|         logger.info("-------------------- Starting training " f"{pair} --------------------") | ||||
|  | ||||
|         # filter the features requested by user in the configuration file and elegantly handle NaNs | ||||
|         features_filtered, labels_filtered = dk.filter_features( | ||||
|             unfiltered_dataframe, | ||||
|             dk.training_features_list, | ||||
|             dk.label_list, | ||||
|             training_filter=True, | ||||
|         ) | ||||
|  | ||||
|         start_date = unfiltered_dataframe["date"].iloc[0].strftime("%Y-%m-%d") | ||||
|         end_date = unfiltered_dataframe["date"].iloc[-1].strftime("%Y-%m-%d") | ||||
|         logger.info(f"-------------------- Training on data from {start_date} to " | ||||
|                     f"{end_date}--------------------") | ||||
|         # split data into train/test data. | ||||
|         data_dictionary = dk.make_train_test_datasets(features_filtered, labels_filtered) | ||||
|         if not self.freqai_info.get('fit_live_predictions', 0) or not self.live: | ||||
|             dk.fit_labels() | ||||
|         # normalize all data based on train_dataset only | ||||
|         data_dictionary = dk.normalize_data(data_dictionary) | ||||
|  | ||||
|         # optional additional data cleaning/analysis | ||||
|         self.data_cleaning_train(dk) | ||||
|  | ||||
|         logger.info( | ||||
|             f'Training model on {len(dk.data_dictionary["train_features"].columns)}' " features" | ||||
|         ) | ||||
|         logger.info(f'Training model on {len(data_dictionary["train_features"])} data points') | ||||
|  | ||||
|         model = self.fit(data_dictionary) | ||||
|  | ||||
|         logger.info(f"--------------------done training {pair}--------------------") | ||||
|  | ||||
|         return model | ||||
|  | ||||
|     def predict( | ||||
|         self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False | ||||
|     ) -> Tuple[DataFrame, npt.NDArray[np.int_]]: | ||||
|         """ | ||||
|         Filter the prediction features data and predict with it. | ||||
|         :param: unfiltered_dataframe: Full dataframe for the current backtest period. | ||||
|         :return: | ||||
|         :pred_df: dataframe containing the predictions | ||||
|         :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove | ||||
|         data (NaNs) or felt uncertain about data (PCA and DI index) | ||||
|         """ | ||||
|  | ||||
|         dk.find_features(unfiltered_dataframe) | ||||
|         filtered_dataframe, _ = dk.filter_features( | ||||
|             unfiltered_dataframe, dk.training_features_list, training_filter=False | ||||
|         ) | ||||
|         filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe) | ||||
|         dk.data_dictionary["prediction_features"] = filtered_dataframe | ||||
|  | ||||
|         self.data_cleaning_predict(dk, filtered_dataframe) | ||||
|  | ||||
|         predictions = self.model.predict(dk.data_dictionary["prediction_features"]) | ||||
|         pred_df = DataFrame(predictions, columns=dk.label_list) | ||||
|  | ||||
|         predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) | ||||
|         pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_) | ||||
|  | ||||
|         pred_df = pd.concat([pred_df, pred_df_prob], axis=1) | ||||
|  | ||||
|         return (pred_df, dk.do_predict) | ||||
| @@ -1,17 +1,15 @@ | ||||
| import logging | ||||
| from typing import Any, Dict, Tuple | ||||
| import pandas as pd | ||||
| from pandas import DataFrame | ||||
| from typing import Any, Dict | ||||
|  | ||||
| from catboost import CatBoostClassifier, Pool | ||||
| import numpy.typing as npt | ||||
| import numpy as np | ||||
| from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
|  | ||||
| from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class CatboostClassifier(BaseRegressionModel): | ||||
| class CatboostClassifier(BaseClassifierModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
| @@ -41,34 +39,3 @@ class CatboostClassifier(BaseRegressionModel): | ||||
|         cbr.fit(train_data) | ||||
|  | ||||
|         return cbr | ||||
|  | ||||
|     def predict( | ||||
|         self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False | ||||
|     ) -> Tuple[DataFrame, npt.NDArray[np.int_]]: | ||||
|         """ | ||||
|         Filter the prediction features data and predict with it. | ||||
|         :param: unfiltered_dataframe: Full dataframe for the current backtest period. | ||||
|         :return: | ||||
|         :pred_df: dataframe containing the predictions | ||||
|         :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove | ||||
|         data (NaNs) or felt uncertain about data (PCA and DI index) | ||||
|         """ | ||||
|  | ||||
|         dk.find_features(unfiltered_dataframe) | ||||
|         filtered_dataframe, _ = dk.filter_features( | ||||
|             unfiltered_dataframe, dk.training_features_list, training_filter=False | ||||
|         ) | ||||
|         filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe) | ||||
|         dk.data_dictionary["prediction_features"] = filtered_dataframe | ||||
|  | ||||
|         self.data_cleaning_predict(dk, filtered_dataframe) | ||||
|  | ||||
|         predictions = self.model.predict(dk.data_dictionary["prediction_features"]) | ||||
|         pred_df = DataFrame(predictions, columns=dk.label_list) | ||||
|  | ||||
|         predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) | ||||
|         pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_) | ||||
|  | ||||
|         pred_df = pd.concat([pred_df, pred_df_prob], axis=1) | ||||
|  | ||||
|         return (pred_df, dk.do_predict) | ||||
|   | ||||
| @@ -3,13 +3,13 @@ from typing import Any, Dict | ||||
|  | ||||
| from lightgbm import LGBMClassifier | ||||
|  | ||||
| from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel | ||||
| from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class LightGBMClassifier(BaseRegressionModel): | ||||
| class LightGBMClassifier(BaseClassifierModel): | ||||
|     """ | ||||
|     User created prediction model. The class needs to override three necessary | ||||
|     functions, predict(), train(), fit(). The class inherits ModelHandler which | ||||
|   | ||||
		Reference in New Issue
	
	Block a user