reformat code
This commit is contained in:
		| @@ -69,7 +69,7 @@ class PyTorchModelTrainer: | ||||
|  | ||||
|         self.model.eval() | ||||
|         epochs = self.calc_n_epochs( | ||||
|             n_obs=len(data_dictionary[f'test_features']), | ||||
|             n_obs=len(data_dictionary['test_features']), | ||||
|             batch_size=self.batch_size, | ||||
|             n_iters=self.eval_iters | ||||
|         ) | ||||
| @@ -101,8 +101,11 @@ class PyTorchModelTrainer: | ||||
|                 torch.from_numpy(data_dictionary[f'{split}_features'].values).float(), | ||||
|                 torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values) | ||||
|                 .long() | ||||
|                 .view(labels_view)  # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel | ||||
|                 .view(labels_view) | ||||
|             ) | ||||
|             # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. | ||||
|             #  need to resolve it per ClassifierModel | ||||
|  | ||||
|             data_loader = DataLoader( | ||||
|                 dataset, | ||||
|                 batch_size=self.batch_size, | ||||
|   | ||||
| @@ -24,7 +24,6 @@ from freqtrade.exceptions import OperationalException | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.strategy.interface import IStrategy | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @@ -90,8 +89,8 @@ class FreqaiDataDrawer: | ||||
|         self.metric_tracker_lock = threading.Lock() | ||||
|         self.old_DBSCAN_eps: Dict[str, float] = {} | ||||
|         self.empty_pair_dict: pair_info = { | ||||
|                 "model_filename": "", "trained_timestamp": 0, | ||||
|                 "data_path": "", "extras": {}} | ||||
|             "model_filename": "", "trained_timestamp": 0, | ||||
|             "data_path": "", "extras": {}} | ||||
|         self.model_type = self.freqai_info.get('model_save_type', 'joblib') | ||||
|  | ||||
|     def update_metric_tracker(self, metric: str, value: float, pair: str) -> None: | ||||
| @@ -446,9 +445,9 @@ class FreqaiDataDrawer: | ||||
|             dump(model, save_path / f"{dk.model_filename}_model.joblib") | ||||
|         elif self.model_type == 'keras': | ||||
|             model.save(save_path / f"{dk.model_filename}_model.h5") | ||||
|         elif 'stable_baselines' in self.model_type or\ | ||||
|                 'sb3_contrib' == self.model_type or\ | ||||
|                 'pytorch' == self.model_type: | ||||
|         elif ('stable_baselines' in self.model_type or | ||||
|               'sb3_contrib' == self.model_type or | ||||
|               'pytorch' == self.model_type): | ||||
|             model.save(save_path / f"{dk.model_filename}_model.zip") | ||||
|  | ||||
|         if dk.svm_model is not None: | ||||
| @@ -581,16 +580,16 @@ class FreqaiDataDrawer: | ||||
|                     if len(df_dp.index) == 0: | ||||
|                         continue | ||||
|                     if str(hist_df.iloc[-1]["date"]) == str( | ||||
|                         df_dp.iloc[-1:]["date"].iloc[-1] | ||||
|                             df_dp.iloc[-1:]["date"].iloc[-1] | ||||
|                     ): | ||||
|                         continue | ||||
|  | ||||
|                     try: | ||||
|                         index = ( | ||||
|                             df_dp.loc[ | ||||
|                                 df_dp["date"] == hist_df.iloc[-1]["date"] | ||||
|                             ].index[0] | ||||
|                             + 1 | ||||
|                                 df_dp.loc[ | ||||
|                                     df_dp["date"] == hist_df.iloc[-1]["date"] | ||||
|                                     ].index[0] | ||||
|                                 + 1 | ||||
|                         ) | ||||
|                     except IndexError: | ||||
|                         if hist_df.iloc[-1]['date'] < df_dp['date'].iloc[0]: | ||||
| @@ -643,7 +642,7 @@ class FreqaiDataDrawer: | ||||
|                 ) | ||||
|  | ||||
|     def get_base_and_corr_dataframes( | ||||
|         self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen | ||||
|             self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen | ||||
|     ) -> Tuple[Dict[Any, Any], Dict[Any, Any]]: | ||||
|         """ | ||||
|         Searches through our historic_data in memory and returns the dataframes relevant | ||||
|   | ||||
		Reference in New Issue
	
	Block a user