Merge pull request #7883 from freqtrade/fix/multioutput-bug

fix bug in MultiOutput* with conv_width = 1
This commit is contained in:
Matthias 2022-12-11 15:52:10 +01:00 committed by GitHub
commit 42afdbb0e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 0 deletions

View File

@ -95,9 +95,14 @@ class BaseClassifierModel(IFreqaiModel):
self.data_cleaning_predict(dk) self.data_cleaning_predict(dk)
predictions = self.model.predict(dk.data_dictionary["prediction_features"]) predictions = self.model.predict(dk.data_dictionary["prediction_features"])
if self.CONV_WIDTH == 1:
predictions = np.reshape(predictions, (-1, len(dk.label_list)))
pred_df = DataFrame(predictions, columns=dk.label_list) pred_df = DataFrame(predictions, columns=dk.label_list)
predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"]) predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"])
if self.CONV_WIDTH == 1:
predictions_prob = np.reshape(predictions_prob, (-1, len(self.model.classes_)))
pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_) pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_)
pred_df = pd.concat([pred_df, pred_df_prob], axis=1) pred_df = pd.concat([pred_df, pred_df_prob], axis=1)

View File

@ -95,6 +95,9 @@ class BaseRegressionModel(IFreqaiModel):
self.data_cleaning_predict(dk) self.data_cleaning_predict(dk)
predictions = self.model.predict(dk.data_dictionary["prediction_features"]) predictions = self.model.predict(dk.data_dictionary["prediction_features"])
if self.CONV_WIDTH == 1:
predictions = np.reshape(predictions, (-1, len(dk.label_list)))
pred_df = DataFrame(predictions, columns=dk.label_list) pred_df = DataFrame(predictions, columns=dk.label_list)
pred_df = dk.denormalize_labels_from_metadata(pred_df) pred_df = dk.denormalize_labels_from_metadata(pred_df)