type hints fixes
@@ -51,7 +51,7 @@ class PyTorchModelTrainer:
             # training
             for batch_data in data_loaders_dictionary['train']:
                 xb, yb = batch_data
-                xb = xb.to(self.device)  # type: ignore
+                xb = xb.to(self.device)
                 yb = yb.to(self.device)
                 yb_pred = self.model(xb)
                 loss = self.criterion(yb_pred, yb)
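For context: `Tensor.to()` is typed as returning a `Tensor`, so once the annotations elsewhere are fixed the `# type: ignore` escape hatch on that line is no longer needed. Below is a minimal, self-contained sketch of the same loop pattern; the model, optimizer, and loader are hypothetical stand-ins for illustration, not the trainer's actual attributes.

    # Sketch only: nn.Linear, AdamW, and the synthetic DataLoader below are
    # illustrative assumptions, not part of PyTorchModelTrainer.
    import torch
    import torch.nn as nn
    from torch import Tensor
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(10, 1).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Stand-in for data_loaders_dictionary['train'].
    dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=16)

    for batch_data in loader:
        xb, yb = batch_data            # both items are torch.Tensor
        xb = xb.to(device)             # .to() returns Tensor, so no "# type: ignore" is needed
        yb = yb.to(device)
        yb_pred: Tensor = model(xb)
        loss = criterion(yb_pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()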
@@ -3,6 +3,7 @@ import logging

 import torch
 import torch.nn as nn
+from torch import Tensor

 logger = logging.getLogger(__name__)

@@ -16,7 +17,7 @@ class PyTorchMLPModel(nn.Module):
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(p=0.2)

-    def forward(self, x: torch.tensor) -> torch.tensor:
+    def forward(self, x: Tensor) -> Tensor:
         x = self.relu(self.input_layer(x))
         x = self.dropout(x)
         x = self.relu(self.hidden_layer(x))
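The annotation change matters because `torch.tensor` is a tensor-constructing function, not a class, so a type checker cannot use it as a type; `torch.Tensor` (imported here as `Tensor`) is the actual class. A minimal sketch of a module annotated this way follows; the class name, layer sizes, and usage are illustrative assumptions, since the commit only shows `forward` and a few surrounding lines.

    # Sketch only: TinyMLP and its layer sizes are hypothetical; the point is
    # the Tensor annotations on forward(), which mypy can actually check.
    import torch
    import torch.nn as nn
    from torch import Tensor


    class TinyMLP(nn.Module):
        def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
            super().__init__()
            self.input_layer = nn.Linear(input_dim, hidden_dim)
            self.hidden_layer = nn.Linear(hidden_dim, hidden_dim)
            self.output_layer = nn.Linear(hidden_dim, output_dim)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(p=0.2)

        def forward(self, x: Tensor) -> Tensor:
            # Annotating with Tensor (the class) instead of torch.tensor
            # (a factory function) gives the type checker a real type.
            x = self.relu(self.input_layer(x))
            x = self.dropout(x)
            x = self.relu(self.hidden_layer(x))
            x = self.dropout(x)
            return self.output_layer(x)


    model = TinyMLP(input_dim=10, hidden_dim=32, output_dim=1)
    out: Tensor = model(torch.randn(4, 10))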