debug classifier with predict proba

This commit is contained in:
robcaulk
2022-08-10 15:16:50 +02:00
parent 23cc21ce59
commit b1b76a2dbe
4 changed files with 21 additions and 24 deletions

View File

@@ -320,6 +320,8 @@ class IFreqaiModel(ABC):
# allows FreqUI to show full return values.
pred_df, do_preds = self.predict(dataframe, dk)
self.dd.set_initial_return_values(pair, dk, pred_df, do_preds)
if pair not in self.dd.historic_predictions:
self.set_initial_historic_predictions(pred_df, dk, pair)
dk.return_dataframe = self.dd.attach_return_values_to_return_dataframe(pair, dataframe)
return
elif self.dk.check_if_model_expired(trained_timestamp):
@@ -336,6 +338,9 @@ class IFreqaiModel(ABC):
# historical accuracy reasons.
pred_df, do_preds = self.predict(dataframe.iloc[-self.CONV_WIDTH:], dk, first=False)
self.dd.save_historic_predictions_to_disk()
if self.freqai_info.get('fit_live_predictions_candles', 0) and self.live:
self.fit_live_predictions(dk)
self.dd.append_model_predictions(pair, pred_df, do_preds, dk, len(dataframe))
dk.return_dataframe = self.dd.attach_return_values_to_return_dataframe(pair, dataframe)
@@ -503,7 +508,7 @@ class IFreqaiModel(ABC):
self.dd.purge_old_models()
def set_initial_historic_predictions(
self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str
self, pred_df: DataFrame, dk: FreqaiDataKitchen, pair: str
) -> None:
"""
This function is called only if the datadrawer failed to load an
@@ -528,12 +533,6 @@ class IFreqaiModel(ABC):
:param: dk: FreqaiDataKitchen = object containing methods for data analysis
:param: pair: str = current pair
"""
num_candles = self.freqai_info.get('fit_live_predictions_candles', 600)
if not num_candles:
num_candles = 600
df_tail = df.tail(num_candles)
trained_predictions = model.predict(df_tail)
pred_df = DataFrame(trained_predictions, columns=dk.label_list)
pred_df = dk.denormalize_labels_from_metadata(pred_df)
@@ -560,9 +559,12 @@ class IFreqaiModel(ABC):
"""
import scipy as spy
# add classes from classifier label types if used
full_labels = dk.label_list + dk.unique_class_list
num_candles = self.freqai_info.get("fit_live_predictions_candles", 100)
dk.data["labels_mean"], dk.data["labels_std"] = {}, {}
for label in dk.label_list:
for label in full_labels:
if self.dd.historic_predictions[dk.pair][label].dtype == object:
continue
f = spy.stats.norm.fit(self.dd.historic_predictions[dk.pair][label].tail(num_candles))