type hints fixes
This commit is contained in:
parent
125085fbaf
commit
8acdd0b47c
@ -51,7 +51,7 @@ class PyTorchModelTrainer:
|
|||||||
# training
|
# training
|
||||||
for batch_data in data_loaders_dictionary['train']:
|
for batch_data in data_loaders_dictionary['train']:
|
||||||
xb, yb = batch_data
|
xb, yb = batch_data
|
||||||
xb = xb.to(self.device) # type: ignore
|
xb = xb.to(self.device)
|
||||||
yb = yb.to(self.device)
|
yb = yb.to(self.device)
|
||||||
yb_pred = self.model(xb)
|
yb_pred = self.model(xb)
|
||||||
loss = self.criterion(yb_pred, yb)
|
loss = self.criterion(yb_pred, yb)
|
||||||
|
@ -3,6 +3,7 @@ import logging
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class PyTorchMLPModel(nn.Module):
|
|||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.dropout = nn.Dropout(p=0.2)
|
self.dropout = nn.Dropout(p=0.2)
|
||||||
|
|
||||||
def forward(self, x: torch.tensor) -> torch.tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
x = self.relu(self.input_layer(x))
|
x = self.relu(self.input_layer(x))
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
x = self.relu(self.hidden_layer(x))
|
x = self.relu(self.hidden_layer(x))
|
||||||
|
Loading…
Reference in New Issue
Block a user