diff --git a/freqtrade/freqai/base_models/BasePyTorchModel.py b/freqtrade/freqai/base_models/BasePyTorchModel.py index 8e608ee1a..d6372fa36 100644 --- a/freqtrade/freqai/base_models/BasePyTorchModel.py +++ b/freqtrade/freqai/base_models/BasePyTorchModel.py @@ -19,9 +19,9 @@ class BasePyTorchModel(IFreqaiModel): """ def __init__(self, **kwargs): - super().__init__(config=kwargs['config']) - self.dd.model_type = 'pytorch' - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + super().__init__(config=kwargs["config"]) + self.dd.model_type = "pytorch" + self.device = "cuda" if torch.cuda.is_available() else "cpu" def train( self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index a934814ef..0ca28d2e9 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -61,7 +61,7 @@ class PyTorchModelTrainer: """ data_loaders_dictionary = self.create_data_loaders_dictionary(data_dictionary) epochs = self.calc_n_epochs( - n_obs=len(data_dictionary['train_features']), + n_obs=len(data_dictionary["train_features"]), batch_size=self.batch_size, n_iters=self.max_iters ) @@ -73,7 +73,7 @@ class PyTorchModelTrainer: f" train loss {losses['train']:.4f} ; test loss {losses['test']:.4f}" ) # training - for batch_data in data_loaders_dictionary['train']: + for batch_data in data_loaders_dictionary["train"]: xb, yb = batch_data xb = xb.to(self.device) yb = yb.to(self.device) @@ -93,12 +93,12 @@ class PyTorchModelTrainer: self.model.eval() epochs = self.calc_n_epochs( - n_obs=len(data_dictionary['test_features']), + n_obs=len(data_dictionary["test_features"]), batch_size=self.batch_size, n_iters=self.eval_iters ) loss_dictionary = {} - for split in ['train', 'test']: + for split in ["train", "test"]: losses = torch.zeros(epochs) for i, batch in enumerate(data_loader_dictionary[split]): xb, yb = batch @@ -121,12 +121,12 @@ class PyTorchModelTrainer: Converts the input data to PyTorch tensors using a data loader. """ data_loader_dictionary = {} - for split in ['train', 'test']: - labels_shape = data_dictionary[f'{split}_labels'].shape + for split in ["train", "test"]: + labels_shape = data_dictionary[f"{split}_labels"].shape labels_view = labels_shape[0] if labels_shape[1] == 1 else labels_shape dataset = TensorDataset( - torch.from_numpy(data_dictionary[f'{split}_features'].values).float(), - torch.from_numpy(data_dictionary[f'{split}_labels'].astype(float).values) + torch.from_numpy(data_dictionary[f"{split}_features"].values).float(), + torch.from_numpy(data_dictionary[f"{split}_labels"].astype(float).values) .long() .view(labels_view) ) @@ -148,6 +148,7 @@ class PyTorchModelTrainer: Calculates the number of epochs required to reach the maximum number of iterations specified in the model training parameters. """ + n_batches = n_obs // batch_size epochs = n_iters // n_batches return epochs @@ -160,9 +161,9 @@ class PyTorchModelTrainer: """ torch.save({ - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'model_meta_data': self.model_meta_data, + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "model_meta_data": self.model_meta_data, }, path) def load_from_file(self, path: Path): diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py index e8326ffe9..a98643b3f 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py @@ -59,7 +59,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): self.init_class_names_to_index_mapping(self.class_names) self.encode_classes_name(data_dictionary, dk) - n_features = data_dictionary['train_features'].shape[-1] + n_features = data_dictionary["train_features"].shape[-1] model = PyTorchMLPModel( input_dim=n_features, hidden_dim=self.n_hidden,