From 8acdd0b47c8cb7239933653b393460a267f39501 Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Mon, 6 Mar 2023 19:14:54 +0200 Subject: [PATCH] type hints fixes --- freqtrade/freqai/base_models/PyTorchModelTrainer.py | 2 +- freqtrade/freqai/prediction_models/PyTorchMLPModel.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 992ad37ef..52fb0ceb5 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -51,7 +51,7 @@ class PyTorchModelTrainer: # training for batch_data in data_loaders_dictionary['train']: xb, yb = batch_data - xb = xb.to(self.device) # type: ignore + xb = xb.to(self.device) yb = yb.to(self.device) yb_pred = self.model(xb) loss = self.criterion(yb_pred, yb) diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPModel.py b/freqtrade/freqai/prediction_models/PyTorchMLPModel.py index 4e1cc32ba..9bbf95019 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPModel.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPModel.py @@ -3,6 +3,7 @@ import logging import torch import torch.nn as nn +from torch import Tensor logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ class PyTorchMLPModel(nn.Module): self.relu = nn.ReLU() self.dropout = nn.Dropout(p=0.2) - def forward(self, x: torch.tensor) -> torch.tensor: + def forward(self, x: Tensor) -> Tensor: x = self.relu(self.input_layer(x)) x = self.dropout(x) x = self.relu(self.hidden_layer(x))