Merge remote-tracking branch 'origin/develop' into feat/convolutional-neural-net
This commit is contained in:
@@ -95,9 +95,14 @@ class BaseClassifierModel(IFreqaiModel):
|
||||
self.data_cleaning_predict(dk)
|
||||
|
||||
predictions = self.model.predict(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions = np.reshape(predictions, (-1, len(dk.label_list)))
|
||||
|
||||
pred_df = DataFrame(predictions, columns=dk.label_list)
|
||||
|
||||
predictions_prob = self.model.predict_proba(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions_prob = np.reshape(predictions_prob, (-1, len(self.model.classes_)))
|
||||
pred_df_prob = DataFrame(predictions_prob, columns=self.model.classes_)
|
||||
|
||||
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
|
||||
|
@@ -95,6 +95,9 @@ class BaseRegressionModel(IFreqaiModel):
|
||||
self.data_cleaning_predict(dk)
|
||||
|
||||
predictions = self.model.predict(dk.data_dictionary["prediction_features"])
|
||||
if self.CONV_WIDTH == 1:
|
||||
predictions = np.reshape(predictions, (-1, len(dk.label_list)))
|
||||
|
||||
pred_df = DataFrame(predictions, columns=dk.label_list)
|
||||
|
||||
pred_df = dk.denormalize_labels_from_metadata(pred_df)
|
||||
|
Reference in New Issue
Block a user