type hints fixes

This commit is contained in:
Yinon Polak 2023-03-06 19:14:54 +02:00
parent 125085fbaf
commit 8acdd0b47c
2 changed files with 3 additions and 2 deletions

View File

@ -51,7 +51,7 @@ class PyTorchModelTrainer:
# training # training
for batch_data in data_loaders_dictionary['train']: for batch_data in data_loaders_dictionary['train']:
xb, yb = batch_data xb, yb = batch_data
xb = xb.to(self.device) # type: ignore xb = xb.to(self.device)
yb = yb.to(self.device) yb = yb.to(self.device)
yb_pred = self.model(xb) yb_pred = self.model(xb)
loss = self.criterion(yb_pred, yb) loss = self.criterion(yb_pred, yb)

View File

@ -3,6 +3,7 @@ import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,7 +17,7 @@ class PyTorchMLPModel(nn.Module):
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2) self.dropout = nn.Dropout(p=0.2)
def forward(self, x: torch.tensor) -> torch.tensor: def forward(self, x: Tensor) -> Tensor:
x = self.relu(self.input_layer(x)) x = self.relu(self.input_layer(x))
x = self.dropout(x) x = self.dropout(x)
x = self.relu(self.hidden_layer(x)) x = self.relu(self.hidden_layer(x))