type hints fixes
This commit is contained in:
@@ -3,6 +3,7 @@ import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,7 +17,7 @@ class PyTorchMLPModel(nn.Module):
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(p=0.2)
|
||||
|
||||
def forward(self, x: torch.tensor) -> torch.tensor:
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.relu(self.input_layer(x))
|
||||
x = self.dropout(x)
|
||||
x = self.relu(self.hidden_layer(x))
|
||||
|
Reference in New Issue
Block a user