diff --git a/freqtrade/freqai/base_models/BaseClassifierModel.py b/freqtrade/freqai/base_models/BaseClassifierModel.py index 17bffa85b..a5cea879f 100644 --- a/freqtrade/freqai/base_models/BaseClassifierModel.py +++ b/freqtrade/freqai/base_models/BaseClassifierModel.py @@ -95,6 +95,9 @@ class BaseClassifierModel(IFreqaiModel): self.data_cleaning_predict(dk) predictions = self.model.predict(dk.data_dictionary["prediction_features"]) + if self.CONV_WIDTH == 1: + predictions = np.reshape(predictions, (-1, len(dk.label_list))) + pred_df = DataFrame(predictions, columns=dk.label_list) predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) diff --git a/freqtrade/freqai/base_models/BaseRegressionModel.py b/freqtrade/freqai/base_models/BaseRegressionModel.py index 766579cb6..1f9b4f5a6 100644 --- a/freqtrade/freqai/base_models/BaseRegressionModel.py +++ b/freqtrade/freqai/base_models/BaseRegressionModel.py @@ -95,6 +95,9 @@ class BaseRegressionModel(IFreqaiModel): self.data_cleaning_predict(dk) predictions = self.model.predict(dk.data_dictionary["prediction_features"]) + if self.CONV_WIDTH == 1: + predictions = np.reshape(predictions, (-1, len(dk.label_list))) + pred_df = DataFrame(predictions, columns=dk.label_list) pred_df = dk.denormalize_labels_from_metadata(pred_df)