debug classifier with predict proba
This commit is contained in:
		| @@ -358,12 +358,7 @@ class FreqaiDataDrawer: | ||||
|  | ||||
|         dk.find_features(dataframe) | ||||
|  | ||||
|         added_labels = [] | ||||
|         if dk.unique_classes: | ||||
|             for label in dk.unique_classes: | ||||
|                 added_labels += dk.unique_classes[label] | ||||
|     | ||||
|         full_labels = dk.label_list + added_labels | ||||
|         full_labels = dk.label_list + dk.unique_class_list | ||||
|  | ||||
|         for label in full_labels: | ||||
|             dataframe[label] = 0 | ||||
|   | ||||
| @@ -91,6 +91,7 @@ class FreqaiDataKitchen: | ||||
|         self.thread_count = self.freqai_config.get("data_kitchen_thread_count", -1) | ||||
|         self.train_dates: DataFrame = pd.DataFrame() | ||||
|         self.unique_classes: Dict[str, list] = {} | ||||
|         self.unique_class_list: list = [] | ||||
|  | ||||
|     def set_paths( | ||||
|         self, | ||||
| @@ -338,7 +339,7 @@ class FreqaiDataKitchen: | ||||
|         """ | ||||
|  | ||||
|         for label in df.columns: | ||||
|             if df[label].dtype == object: | ||||
|             if df[label].dtype == object or label in self.unique_class_list: | ||||
|                 continue | ||||
|             df[label] = ( | ||||
|                 (df[label] + 1) | ||||
| @@ -995,6 +996,10 @@ class FreqaiDataKitchen: | ||||
|             f = spy.stats.norm.fit(self.data_dictionary["train_labels"][label]) | ||||
|             self.data["labels_mean"][label], self.data["labels_std"][label] = f[0], f[1] | ||||
|  | ||||
|         # incase targets are classifications | ||||
|         for label in self.unique_class_list: | ||||
|             self.data["labels_mean"][label], self.data["labels_std"][label] = 0, 0 | ||||
|  | ||||
|         return | ||||
|  | ||||
|     def remove_features_from_df(self, dataframe: DataFrame) -> DataFrame: | ||||
| @@ -1014,3 +1019,7 @@ class FreqaiDataKitchen: | ||||
|         for key in self.label_list: | ||||
|             if dataframe[key].dtype == object: | ||||
|                 self.unique_classes[key] = dataframe[key].dropna().unique() | ||||
|  | ||||
|         if self.unique_classes: | ||||
|             for label in self.unique_classes: | ||||
|                 self.unique_class_list += list(self.unique_classes[label]) | ||||
|   | ||||
| @@ -320,6 +320,8 @@ class IFreqaiModel(ABC): | ||||
|             # allows FreqUI to show full return values. | ||||
|             pred_df, do_preds = self.predict(dataframe, dk) | ||||
|             self.dd.set_initial_return_values(pair, dk, pred_df, do_preds) | ||||
|             if pair not in self.dd.historic_predictions: | ||||
|                 self.set_initial_historic_predictions(pred_df, dk, pair) | ||||
|             dk.return_dataframe = self.dd.attach_return_values_to_return_dataframe(pair, dataframe) | ||||
|             return | ||||
|         elif self.dk.check_if_model_expired(trained_timestamp): | ||||
| @@ -336,6 +338,9 @@ class IFreqaiModel(ABC): | ||||
|             # historical accuracy reasons. | ||||
|             pred_df, do_preds = self.predict(dataframe.iloc[-self.CONV_WIDTH:], dk, first=False) | ||||
|  | ||||
|         self.dd.save_historic_predictions_to_disk() | ||||
|         if self.freqai_info.get('fit_live_predictions_candles', 0) and self.live: | ||||
|             self.fit_live_predictions(dk) | ||||
|         self.dd.append_model_predictions(pair, pred_df, do_preds, dk, len(dataframe)) | ||||
|         dk.return_dataframe = self.dd.attach_return_values_to_return_dataframe(pair, dataframe) | ||||
|  | ||||
| @@ -503,7 +508,7 @@ class IFreqaiModel(ABC): | ||||
|             self.dd.purge_old_models() | ||||
|  | ||||
|     def set_initial_historic_predictions( | ||||
|         self, df: DataFrame, model: Any, dk: FreqaiDataKitchen, pair: str | ||||
|         self, pred_df: DataFrame, dk: FreqaiDataKitchen, pair: str | ||||
|     ) -> None: | ||||
|         """ | ||||
|         This function is called only if the datadrawer failed to load an | ||||
| @@ -528,12 +533,6 @@ class IFreqaiModel(ABC): | ||||
|         :param: dk: FreqaiDataKitchen = object containing methods for data analysis | ||||
|         :param: pair: str = current pair | ||||
|         """ | ||||
|         num_candles = self.freqai_info.get('fit_live_predictions_candles', 600) | ||||
|         if not num_candles: | ||||
|             num_candles = 600 | ||||
|         df_tail = df.tail(num_candles) | ||||
|         trained_predictions = model.predict(df_tail) | ||||
|         pred_df = DataFrame(trained_predictions, columns=dk.label_list) | ||||
|  | ||||
|         pred_df = dk.denormalize_labels_from_metadata(pred_df) | ||||
|  | ||||
| @@ -560,9 +559,12 @@ class IFreqaiModel(ABC): | ||||
|         """ | ||||
|         import scipy as spy | ||||
|  | ||||
|         # add classes from classifier label types if used | ||||
|         full_labels = dk.label_list + dk.unique_class_list | ||||
|  | ||||
|         num_candles = self.freqai_info.get("fit_live_predictions_candles", 100) | ||||
|         dk.data["labels_mean"], dk.data["labels_std"] = {}, {} | ||||
|         for label in dk.label_list: | ||||
|         for label in full_labels: | ||||
|             if self.dd.historic_predictions[dk.pair][label].dtype == object: | ||||
|                 continue | ||||
|             f = spy.stats.norm.fit(self.dd.historic_predictions[dk.pair][label].tail(num_candles)) | ||||
|   | ||||
| @@ -62,15 +62,6 @@ class BaseRegressionModel(IFreqaiModel): | ||||
|  | ||||
|         model = self.fit(data_dictionary) | ||||
|  | ||||
|         if pair not in self.dd.historic_predictions: | ||||
|             self.set_initial_historic_predictions( | ||||
|                 data_dictionary['train_features'], model, dk, pair) | ||||
|  | ||||
|         if self.freqai_info.get('fit_live_predictions_candles', 0) and self.live: | ||||
|             self.fit_live_predictions(dk) | ||||
|  | ||||
|         self.dd.save_historic_predictions_to_disk() | ||||
|  | ||||
|         logger.info(f"--------------------done training {pair}--------------------") | ||||
|  | ||||
|         return model | ||||
|   | ||||
		Reference in New Issue
	
	Block a user