add pytorch data convertor

This commit is contained in:
Yinon Polak
2023-04-03 15:19:10 +03:00
parent 5a7ca35c6b
commit bd3b70293f
9 changed files with 168 additions and 40 deletions

View File

@@ -1,4 +1,5 @@
import logging
from typing import Tuple, List
import torch
import torch.nn as nn
@@ -46,7 +47,8 @@ class PyTorchMLPModel(nn.Module):
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=dropout_percent)
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
x, = x
x = self.relu(self.input_layer(x))
x = self.dropout(x)
x = self.blocks(x)