clean code
This commit is contained in:
parent
bf4aa91aab
commit
6b4d9f97c1
@ -36,7 +36,7 @@ class PyTorchMLPModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_dim: int, output_dim: int, **kwargs):
|
def __init__(self, input_dim: int, output_dim: int, **kwargs):
|
||||||
super(PyTorchMLPModel, self).__init__()
|
super().__init__()
|
||||||
hidden_dim: int = kwargs.get("hidden_dim", 256)
|
hidden_dim: int = kwargs.get("hidden_dim", 256)
|
||||||
dropout_percent: int = kwargs.get("dropout_percent", 0.2)
|
dropout_percent: int = kwargs.get("dropout_percent", 0.2)
|
||||||
n_layer: int = kwargs.get("n_layer", 1)
|
n_layer: int = kwargs.get("n_layer", 1)
|
||||||
@ -65,7 +65,7 @@ class Block(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hidden_dim: int, dropout_percent: int):
|
def __init__(self, hidden_dim: int, dropout_percent: int):
|
||||||
super(Block, self).__init__()
|
super().__init__()
|
||||||
self.ff = FeedForward(hidden_dim)
|
self.ff = FeedForward(hidden_dim)
|
||||||
self.dropout = nn.Dropout(p=dropout_percent)
|
self.dropout = nn.Dropout(p=dropout_percent)
|
||||||
self.ln = nn.LayerNorm(hidden_dim)
|
self.ln = nn.LayerNorm(hidden_dim)
|
||||||
@ -85,7 +85,7 @@ class FeedForward(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hidden_dim: int):
|
def __init__(self, hidden_dim: int):
|
||||||
super(FeedForward, self).__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, hidden_dim),
|
nn.Linear(hidden_dim, hidden_dim),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
|
@ -47,4 +47,4 @@ class PyTorchRegressor(BasePyTorchModel):
|
|||||||
|
|
||||||
y = self.model.model(x)
|
y = self.model.model(x)
|
||||||
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
|
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
|
||||||
return (pred_df, dk.do_predict)
|
return (pred_df, dk.do_predict)
|
||||||
|
@ -229,7 +229,6 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
if freqai.dd.model_type == 'joblib':
|
if freqai.dd.model_type == 'joblib':
|
||||||
model_file_extension = ".joblib"
|
model_file_extension = ".joblib"
|
||||||
elif freqai.dd.model_type == "pytorch":
|
elif freqai.dd.model_type == "pytorch":
|
||||||
|
Loading…
Reference in New Issue
Block a user