clean code

This commit is contained in:
Yinon Polak 2023-03-20 19:28:30 +02:00
parent bf4aa91aab
commit 6b4d9f97c1
3 changed files with 4 additions and 5 deletions

View File

@ -36,7 +36,7 @@ class PyTorchMLPModel(nn.Module):
""" """
def __init__(self, input_dim: int, output_dim: int, **kwargs): def __init__(self, input_dim: int, output_dim: int, **kwargs):
super(PyTorchMLPModel, self).__init__() super().__init__()
hidden_dim: int = kwargs.get("hidden_dim", 256) hidden_dim: int = kwargs.get("hidden_dim", 256)
dropout_percent: int = kwargs.get("dropout_percent", 0.2) dropout_percent: int = kwargs.get("dropout_percent", 0.2)
n_layer: int = kwargs.get("n_layer", 1) n_layer: int = kwargs.get("n_layer", 1)
@ -65,7 +65,7 @@ class Block(nn.Module):
""" """
def __init__(self, hidden_dim: int, dropout_percent: int): def __init__(self, hidden_dim: int, dropout_percent: int):
super(Block, self).__init__() super().__init__()
self.ff = FeedForward(hidden_dim) self.ff = FeedForward(hidden_dim)
self.dropout = nn.Dropout(p=dropout_percent) self.dropout = nn.Dropout(p=dropout_percent)
self.ln = nn.LayerNorm(hidden_dim) self.ln = nn.LayerNorm(hidden_dim)
@ -85,7 +85,7 @@ class FeedForward(nn.Module):
""" """
def __init__(self, hidden_dim: int): def __init__(self, hidden_dim: int):
super(FeedForward, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(), nn.ReLU(),

View File

@ -229,7 +229,6 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
} }
}) })
if freqai.dd.model_type == 'joblib': if freqai.dd.model_type == 'joblib':
model_file_extension = ".joblib" model_file_extension = ".joblib"
elif freqai.dd.model_type == "pytorch": elif freqai.dd.model_type == "pytorch":