2023-03-05 14:59:24 +00:00
|
|
|
import logging
|
|
|
|
|
2023-03-08 14:03:36 +00:00
|
|
|
import torch.nn as nn
|
2023-03-08 14:08:04 +00:00
|
|
|
from torch import Tensor
|
2023-03-05 14:59:24 +00:00
|
|
|
|
2023-03-08 14:10:25 +00:00
|
|
|
|
2023-03-05 14:59:24 +00:00
|
|
|
# Logger scoped to this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
2023-03-06 14:16:45 +00:00
|
|
|
class PyTorchMLPModel(nn.Module):
    """Multi-layer perceptron: input projection, a stack of ``Block``s, output projection.

    :param input_dim: number of input features (size of the last dimension of ``x``).
    :param output_dim: number of output features produced by ``forward``.
    :param kwargs: optional hyperparameters; any other keys are ignored.
        * ``hidden_dim`` (int, default 1024): width of every hidden layer.
        * ``dropout_percent`` (float, default 0.2): dropout probability, used by the
          dropout after the input layer and by every ``Block``.
        * ``n_layer`` (int, default 1): number of ``Block``s between the input and
          output layers; 0 gives an input layer directly followed by the output layer.
    """

    def __init__(self, input_dim: int, output_dim: int, **kwargs):
        super().__init__()
        hidden_dim: int = kwargs.get("hidden_dim", 1024)
        # A probability handed to nn.Dropout, which expects a float in [0, 1].
        dropout_percent: float = kwargs.get("dropout_percent", 0.2)
        n_layer: int = kwargs.get("n_layer", 1)
        # Submodules are created in this exact order so parameter initialization
        # (RNG consumption) and state_dict key order stay stable.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.blocks = nn.Sequential(*[Block(hidden_dim, dropout_percent) for _ in range(n_layer)])
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_percent)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``x`` (last dim ``input_dim``) to raw outputs (last dim ``output_dim``).

        No activation is applied to the output layer, so the result is unnormalized.
        """
        x = self.relu(self.input_layer(x))
        x = self.dropout(x)
        x = self.blocks(x)
        return self.output_layer(x)
|
2023-03-12 12:31:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """One hidden stage of the MLP: LayerNorm -> FeedForward -> Dropout.

    The output replaces the input; there is no residual/skip connection.
    NOTE(review): pre-norm blocks often add ``x + ...``; confirm its absence is intended.

    :param hidden_dim: feature size of both the input and the output.
    :param dropout_percent: dropout probability (float in [0, 1]) applied to the
        FeedForward output.
    """

    def __init__(self, hidden_dim: int, dropout_percent: float):
        super().__init__()
        # Creation order (ff, dropout, ln) kept so state_dict key order is unchanged.
        self.ff = FeedForward(hidden_dim)
        self.dropout = nn.Dropout(p=dropout_percent)
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x``, pass it through the feed-forward layer, then apply dropout."""
        # Normalization happens before the feed-forward layer (pre-norm).
        x = self.ff(self.ln(x))
        return self.dropout(x)
|
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """A width-preserving ``Linear(hidden_dim, hidden_dim)`` followed by ``ReLU``.

    :param hidden_dim: feature size of both the input and the output.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        # Wrapped in an nn.Sequential named ``net`` so parameters are exposed as
        # ``net.0.weight`` / ``net.0.bias``.
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear layer and then ReLU to ``x``."""
        out = self.net(x)
        return out
|