generalize mlp model

Yinon Polak 2023-03-12 14:31:08 +02:00
parent 1cf0e7be24
commit f9fdf1c31b
2 changed files with 43 additions and 14 deletions
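
The commit moves the trainer settings from model_training_parameters to a trainer_kwargs section and adds a nested model_kwargs dict that is passed straight through to the network constructor. Below is a minimal, illustrative sketch of how that section might be supplied through the freqai configuration; the surrounding layout is an assumption based on the self.freqai_info lookup in the diff, and the values are examples only (the keys hidden_dim, dropout_percent and n_layer come from PyTorchMLPModel further down).

# Assumed config layout, for illustration only.
config = {
    "freqai": {
        "trainer_kwargs": {
            "max_iters": 100,
            "batch_size": 64,
            "learning_rate": 3e-4,
            "model_kwargs": {
                "hidden_dim": 512,
                "dropout_percent": 0.2,
                "n_layer": 2,
            },
        },
    },
}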


@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Optional

 import numpy as np
 import numpy.typing as npt
@@ -34,12 +34,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         """
         super().__init__(**kwargs)
-        model_training_parameters = self.freqai_info["model_training_parameters"]
-        self.n_hidden = model_training_parameters.get("n_hidden", 1024)
-        self.max_iters = model_training_parameters.get("max_iters", 100)
-        self.batch_size = model_training_parameters.get("batch_size", 64)
-        self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
-        self.max_n_eval_batches = model_training_parameters.get("max_n_eval_batches", None)
+        trainer_kwargs = self.freqai_info.get("trainer_kwargs", {})
+        self.n_hidden: int = trainer_kwargs.get("n_hidden", 1024)
+        self.max_iters: int = trainer_kwargs.get("max_iters", 100)
+        self.batch_size: int = trainer_kwargs.get("batch_size", 64)
+        self.learning_rate: float = trainer_kwargs.get("learning_rate", 3e-4)
+        self.max_n_eval_batches: Optional[int] = trainer_kwargs.get("max_n_eval_batches", None)
+        self.model_kwargs: Dict = trainer_kwargs.get("model_kwargs", {})
         self.class_name_to_index = None
         self.index_to_class_name = None
@@ -64,8 +65,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         n_features = data_dictionary["train_features"].shape[-1]
         model = PyTorchMLPModel(
             input_dim=n_features,
-            hidden_dim=self.n_hidden,
-            output_dim=len(self.class_names)
+            output_dim=len(self.class_names),
+            **self.model_kwargs
         )
         model.to(self.device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
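
With the classifier-side changes above, the network shape is no longer pinned by n_hidden: whatever the user puts under model_kwargs is forwarded verbatim to PyTorchMLPModel, and an empty dict falls back to the defaults defined inside the model. A rough sketch of the resulting construction path (freqai_info, n_features and class_names are stand-ins for the instance attributes used in the diff):

# Sketch of the new config-to-model flow; names are stand-ins, not the real attributes.
trainer_kwargs = freqai_info.get("trainer_kwargs", {})
model_kwargs = trainer_kwargs.get("model_kwargs", {})  # e.g. {"hidden_dim": 512, "n_layer": 2}
model = PyTorchMLPModel(
    input_dim=n_features,
    output_dim=len(class_names),
    **model_kwargs,  # {} keeps the model's own defaults
)

The second changed file, shown below, contains the generalized PyTorchMLPModel itself.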


@@ -8,18 +8,46 @@
 logger = logging.getLogger(__name__)

 class PyTorchMLPModel(nn.Module):
-    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, output_dim: int, **kwargs):
         super(PyTorchMLPModel, self).__init__()
+        hidden_dim: int = kwargs.get("hidden_dim", 1024)
+        dropout_percent: int = kwargs.get("dropout_percent", 0.2)
+        n_layer: int = kwargs.get("n_layer", 1)
         self.input_layer = nn.Linear(input_dim, hidden_dim)
-        self.hidden_layer = nn.Linear(hidden_dim, hidden_dim)
+        self.blocks = nn.Sequential(*[Block(hidden_dim, dropout_percent) for _ in range(n_layer)])
         self.output_layer = nn.Linear(hidden_dim, output_dim)
         self.relu = nn.ReLU()
-        self.dropout = nn.Dropout(p=0.2)
+        self.dropout = nn.Dropout(p=dropout_percent)

     def forward(self, x: Tensor) -> Tensor:
         x = self.relu(self.input_layer(x))
         x = self.dropout(x)
-        x = self.relu(self.hidden_layer(x))
-        x = self.dropout(x)
+        x = self.relu(self.blocks(x))
         logits = self.output_layer(x)
         return logits
+
+
+class Block(nn.Module):
+    def __init__(self, hidden_dim: int, dropout_percent: int):
+        super(Block, self).__init__()
+        self.ff = FeedForward(hidden_dim)
+        self.dropout = nn.Dropout(p=dropout_percent)
+        self.ln = nn.LayerNorm(hidden_dim)
+
+    def forward(self, x):
+        x = self.dropout(self.ff(x))
+        x = self.ln(x)
+        return x
+
+
+class FeedForward(nn.Module):
+    def __init__(self, hidden_dim: int):
+        super(FeedForward, self).__init__()
+        self.net = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+        )
+
+    def forward(self, x):
+        return self.net(x)
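
The generalized network is now input layer -> n_layer stacked Blocks (each a two-layer FeedForward followed by Dropout and LayerNorm) -> output layer, with width, depth and dropout rate all taken from kwargs. A standalone usage sketch, not part of the commit, assuming torch is installed and the classes above are importable:

# Illustrative only: exercise the generalized MLP and check output shapes.
import torch

model = PyTorchMLPModel(input_dim=10, output_dim=3,
                        hidden_dim=64, dropout_percent=0.1, n_layer=3)
x = torch.randn(32, 10)      # batch of 32 samples, 10 features each
logits = model(x)
print(logits.shape)          # torch.Size([32, 3])

Calling it with no extra kwargs falls back to hidden_dim=1024, dropout_percent=0.2 and a single Block.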