make BaseClassifierModel. Add predict_proba to lightgbm

This commit is contained in:
robcaulk
2022-08-13 20:07:31 +02:00
parent 31be707cc8
commit 58de20af0f
4 changed files with 108 additions and 42 deletions

View File

@@ -1,17 +1,15 @@
import logging
from typing import Any, Dict, Tuple
import pandas as pd
from pandas import DataFrame
from typing import Any, Dict
from catboost import CatBoostClassifier, Pool
import numpy.typing as npt
import numpy as np
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseClassifierModel import BaseClassifierModel
logger = logging.getLogger(__name__)
class CatboostClassifier(BaseRegressionModel):
class CatboostClassifier(BaseClassifierModel):
"""
User created prediction model. The class needs to override three necessary
functions, predict(), train(), fit(). The class inherits ModelHandler which
@@ -41,34 +39,3 @@ class CatboostClassifier(BaseRegressionModel):
cbr.fit(train_data)
return cbr
def predict(
self, unfiltered_dataframe: DataFrame, dk: FreqaiDataKitchen, first: bool = False
) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
"""
Filter the prediction features data and predict with it.
:param: unfiltered_dataframe: Full dataframe for the current backtest period.
:return:
:pred_df: dataframe containing the predictions
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
data (NaNs) or felt uncertain about data (PCA and DI index)
"""
dk.find_features(unfiltered_dataframe)
filtered_dataframe, _ = dk.filter_features(
unfiltered_dataframe, dk.training_features_list, training_filter=False
)
filtered_dataframe = dk.normalize_data_from_metadata(filtered_dataframe)
dk.data_dictionary["prediction_features"] = filtered_dataframe
self.data_cleaning_predict(dk, filtered_dataframe)
predictions = self.model.predict(dk.data_dictionary["prediction_features"])
pred_df = DataFrame(predictions, columns=dk.label_list)
predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"])
pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_)
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
return (pred_df, dk.do_predict)