# stable/freqtrade/freqai/prediction_models/PyTorchMLPModel.py
import logging
import torch.nn as nn
import torch
logger = logging.getLogger(__name__)
class PyTorchMLPModel(nn.Module):
    """
    A multi-layer perceptron (MLP) model implemented using PyTorch.

    :param input_dim: The number of input features. This parameter specifies the number
        of features in the input data that the MLP will use to make predictions.
    :param output_dim: The number of output classes. This parameter specifies the number
        of classes that the MLP will predict.
    :param hidden_dim: The number of hidden units in each layer. This parameter controls
        the complexity of the MLP and determines how many nonlinear relationships the MLP
        can represent. Increasing the number of hidden units can increase the capacity of
        the MLP to model complex patterns, but it also increases the risk of overfitting
        the training data. Default: 256
    :param dropout_percent: The dropout rate for regularization. This parameter specifies
        the probability of dropping out a neuron during training to prevent overfitting.
        The dropout rate should be tuned carefully to balance between underfitting and
        overfitting. Default: 0.2
    :param n_layer: The number of layers in the MLP. This parameter specifies the number
        of layers in the MLP architecture. Adding more layers to the MLP can increase its
        capacity to model complex patterns, but it also increases the risk of overfitting
        the training data. Default: 1

    :returns: The output of the MLP, with shape (batch_size, output_dim)
    """

    def __init__(self, input_dim: int, output_dim: int, **kwargs):
        super().__init__()
        hidden_dim: int = kwargs.get("hidden_dim", 256)
        # Fixed annotation: dropout is a probability in [0, 1], not an int.
        dropout_percent: float = kwargs.get("dropout_percent", 0.2)
        n_layer: int = kwargs.get("n_layer", 1)
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.blocks = nn.Sequential(*[Block(hidden_dim, dropout_percent) for _ in range(n_layer)])
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_percent)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run a forward pass.

        :param x: Input tensor of shape (batch_size, input_dim).
        :returns: Raw logits of shape (batch_size, output_dim); no final
            activation is applied here (loss functions are expected to do it).
        """
        x = self.relu(self.input_layer(x))
        x = self.dropout(x)
        x = self.blocks(x)
        logits = self.output_layer(x)
        return logits
class Block(nn.Module):
    """
    A building block for a multi-layer perceptron (MLP): layer-norm, a
    feedforward sub-block, then dropout.

    :param hidden_dim: The number of hidden units in the feedforward network.
    :param dropout_percent: The dropout rate for regularization.
    :returns: torch.Tensor. with shape (batch_size, hidden_dim)
    """

    # Fixed annotation: dropout_percent is a probability (float), not an int.
    def __init__(self, hidden_dim: int, dropout_percent: float):
        super().__init__()
        self.ff = FeedForward(hidden_dim)
        self.dropout = nn.Dropout(p=dropout_percent)
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: Input tensor of shape (batch_size, hidden_dim).
        :returns: Tensor of the same shape.
        """
        # Normalize before the feedforward pass, then apply dropout.
        x = self.ff(self.ln(x))
        x = self.dropout(x)
        return x
class FeedForward(nn.Module):
    """
    A simple fully-connected feedforward neural network block.

    Applies a single hidden-dim-preserving linear layer followed by a ReLU
    activation.

    :param hidden_dim: The number of hidden units in the block.
    :return: torch.Tensor. with shape (batch_size, hidden_dim)
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        # Build the layer list first, then wrap it in a Sequential container.
        layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear + ReLU to *x* and return the result."""
        return self.net(x)