add pytorch data convertor
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -46,7 +47,8 @@ class PyTorchMLPModel(nn.Module):
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(p=dropout_percent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
|
||||
x, = x
|
||||
x = self.relu(self.input_layer(x))
|
||||
x = self.dropout(x)
|
||||
x = self.blocks(x)
|
||||
|
||||
Reference in New Issue
Block a user