From 85f22b5c3029a3f613d0b0da7b61eeef8f6685d5 Mon Sep 17 00:00:00 2001 From: robcaulk Date: Sun, 11 Dec 2022 12:15:19 +0100 Subject: [PATCH 1/2] fix bug in MultiOutput* with conv_width = 1 --- freqtrade/freqai/base_models/BaseClassifierModel.py | 3 +++ freqtrade/freqai/base_models/BaseRegressionModel.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/freqtrade/freqai/base_models/BaseClassifierModel.py b/freqtrade/freqai/base_models/BaseClassifierModel.py index 17bffa85b..a5cea879f 100644 --- a/freqtrade/freqai/base_models/BaseClassifierModel.py +++ b/freqtrade/freqai/base_models/BaseClassifierModel.py @@ -95,6 +95,9 @@ class BaseClassifierModel(IFreqaiModel): self.data_cleaning_predict(dk) predictions = self.model.predict(dk.data_dictionary["prediction_features"]) + if self.CONV_WIDTH == 1: + predictions = np.reshape(predictions, (-1, len(dk.label_list))) + pred_df = DataFrame(predictions, columns=dk.label_list) predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) diff --git a/freqtrade/freqai/base_models/BaseRegressionModel.py b/freqtrade/freqai/base_models/BaseRegressionModel.py index 766579cb6..1f9b4f5a6 100644 --- a/freqtrade/freqai/base_models/BaseRegressionModel.py +++ b/freqtrade/freqai/base_models/BaseRegressionModel.py @@ -95,6 +95,9 @@ class BaseRegressionModel(IFreqaiModel): self.data_cleaning_predict(dk) predictions = self.model.predict(dk.data_dictionary["prediction_features"]) + if self.CONV_WIDTH == 1: + predictions = np.reshape(predictions, (-1, len(dk.label_list))) + pred_df = DataFrame(predictions, columns=dk.label_list) pred_df = dk.denormalize_labels_from_metadata(pred_df) From 8c7ec07951eadf53a5722fe7d7489e9a95e5ab46 Mon Sep 17 00:00:00 2001 From: robcaulk Date: Sun, 11 Dec 2022 12:39:31 +0100 Subject: [PATCH 2/2] ensure predict_proba follows suit. Remove all lib specific params from example config --- config_examples/config_freqai.example.json | 1 - freqtrade/freqai/base_models/BaseClassifierModel.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/config_examples/config_freqai.example.json b/config_examples/config_freqai.example.json index 5e564a1fc..f58a4468b 100644 --- a/config_examples/config_freqai.example.json +++ b/config_examples/config_freqai.example.json @@ -80,7 +80,6 @@ "random_state": 1 }, "model_training_parameters": { - "n_estimators": 1000 } }, "bot_name": "", diff --git a/freqtrade/freqai/base_models/BaseClassifierModel.py b/freqtrade/freqai/base_models/BaseClassifierModel.py index a5cea879f..ffd42dd1d 100644 --- a/freqtrade/freqai/base_models/BaseClassifierModel.py +++ b/freqtrade/freqai/base_models/BaseClassifierModel.py @@ -101,6 +101,8 @@ class BaseClassifierModel(IFreqaiModel): pred_df = DataFrame(predictions, columns=dk.label_list) predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) + if self.CONV_WIDTH == 1: + predictions_prob = np.reshape(predictions_prob, (-1, len(self.model.classes_))) pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_) pred_df = pd.concat([pred_df, pred_df_prob], axis=1)