pytorch mlp rename input to fix mypy error
This commit is contained in:
parent
26738370c7
commit
a655524221
@ -47,8 +47,8 @@ class PyTorchMLPModel(nn.Module):
|
|||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.dropout = nn.Dropout(p=dropout_percent)
|
self.dropout = nn.Dropout(p=dropout_percent)
|
||||||
|
|
||||||
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
|
def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
|
||||||
x: torch.Tensor = x[0]
|
x: torch.Tensor = tensors[0]
|
||||||
x = self.relu(self.input_layer(x))
|
x = self.relu(self.input_layer(x))
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
x = self.blocks(x)
|
x = self.blocks(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user