From e6e747bcd819b28336dbf4232c6d23226102e6bf Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Mon, 6 Mar 2023 17:50:02 +0200 Subject: [PATCH] reformat code --- .../freqai/base_models/PyTorchModelTrainer.py | 7 ++++-- freqtrade/freqai/data_drawer.py | 23 +++++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 03d264371..992ad37ef 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -69,7 +69,7 @@ class PyTorchModelTrainer: self.model.eval() epochs = self.calc_n_epochs( - n_obs=len(data_dictionary[f'test_features']), + n_obs=len(data_dictionary['test_features']), batch_size=self.batch_size, n_iters=self.eval_iters ) @@ -101,8 +101,11 @@ class PyTorchModelTrainer: torch.from_numpy(data_dictionary[f'{split}_features'].values).float(), torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values) .long() - .view(labels_view) # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. need to resolve it per ClassifierModel + .view(labels_view) ) + # todo currently assuming class labels are strings ['0.0', '1.0' .. n_classes]. + # need to resolve it per ClassifierModel + data_loader = DataLoader( dataset, batch_size=self.batch_size, diff --git a/freqtrade/freqai/data_drawer.py b/freqtrade/freqai/data_drawer.py index d167a39eb..aecab0640 100644 --- a/freqtrade/freqai/data_drawer.py +++ b/freqtrade/freqai/data_drawer.py @@ -24,7 +24,6 @@ from freqtrade.exceptions import OperationalException from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.strategy.interface import IStrategy - logger = logging.getLogger(__name__) @@ -90,8 +89,8 @@ class FreqaiDataDrawer: self.metric_tracker_lock = threading.Lock() self.old_DBSCAN_eps: Dict[str, float] = {} self.empty_pair_dict: pair_info = { - "model_filename": "", "trained_timestamp": 0, - "data_path": "", "extras": {}} + "model_filename": "", "trained_timestamp": 0, + "data_path": "", "extras": {}} self.model_type = self.freqai_info.get('model_save_type', 'joblib') def update_metric_tracker(self, metric: str, value: float, pair: str) -> None: @@ -446,9 +445,9 @@ class FreqaiDataDrawer: dump(model, save_path / f"{dk.model_filename}_model.joblib") elif self.model_type == 'keras': model.save(save_path / f"{dk.model_filename}_model.h5") - elif 'stable_baselines' in self.model_type or\ - 'sb3_contrib' == self.model_type or\ - 'pytorch' == self.model_type: + elif ('stable_baselines' in self.model_type or + 'sb3_contrib' == self.model_type or + 'pytorch' == self.model_type): model.save(save_path / f"{dk.model_filename}_model.zip") if dk.svm_model is not None: @@ -581,16 +580,16 @@ class FreqaiDataDrawer: if len(df_dp.index) == 0: continue if str(hist_df.iloc[-1]["date"]) == str( - df_dp.iloc[-1:]["date"].iloc[-1] + df_dp.iloc[-1:]["date"].iloc[-1] ): continue try: index = ( - df_dp.loc[ - df_dp["date"] == hist_df.iloc[-1]["date"] - ].index[0] - + 1 + df_dp.loc[ + df_dp["date"] == hist_df.iloc[-1]["date"] + ].index[0] + + 1 ) except IndexError: if hist_df.iloc[-1]['date'] < df_dp['date'].iloc[0]: @@ -643,7 +642,7 @@ class FreqaiDataDrawer: ) def get_base_and_corr_dataframes( - self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen + self, timerange: TimeRange, pair: str, dk: FreqaiDataKitchen ) -> Tuple[Dict[Any, Any], Dict[Any, Any]]: """ Searches through our historic_data in memory and returns the dataframes relevant