diff --git a/freqtrade/freqai/base_models/BaseClassifierModel.py b/freqtrade/freqai/base_models/BaseClassifierModel.py index 17bffa85b..ffd42dd1d 100644 --- a/freqtrade/freqai/base_models/BaseClassifierModel.py +++ b/freqtrade/freqai/base_models/BaseClassifierModel.py @@ -95,9 +95,14 @@ class BaseClassifierModel(IFreqaiModel): self.data_cleaning_predict(dk) predictions = self.model.predict(dk.data_dictionary["prediction_features"]) + if self.CONV_WIDTH == 1: + predictions = np.reshape(predictions, (-1, len(dk.label_list))) + pred_df = DataFrame(predictions, columns=dk.label_list) predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) + if self.CONV_WIDTH == 1: + predictions_prob = np.reshape(predictions_prob, (-1, len(self.model.classes_))) pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_) pred_df = pd.concat([pred_df, pred_df_prob], axis=1) diff --git a/freqtrade/freqai/base_models/BaseRegressionModel.py b/freqtrade/freqai/base_models/BaseRegressionModel.py index 766579cb6..1f9b4f5a6 100644 --- a/freqtrade/freqai/base_models/BaseRegressionModel.py +++ b/freqtrade/freqai/base_models/BaseRegressionModel.py @@ -95,6 +95,9 @@ class BaseRegressionModel(IFreqaiModel): self.data_cleaning_predict(dk) predictions = self.model.predict(dk.data_dictionary["prediction_features"]) + if self.CONV_WIDTH == 1: + predictions = np.reshape(predictions, (-1, len(dk.label_list))) + pred_df = DataFrame(predictions, columns=dk.label_list) pred_df = dk.denormalize_labels_from_metadata(pred_df)