From 6b4d9f97c13472267c5b6d7ef920eafd8001acb3 Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Mon, 20 Mar 2023 19:28:30 +0200 Subject: [PATCH] clean code --- freqtrade/freqai/prediction_models/PyTorchMLPModel.py | 6 +++--- freqtrade/freqai/prediction_models/PyTorchRegressor.py | 2 +- tests/freqai/test_freqai_interface.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPModel.py b/freqtrade/freqai/prediction_models/PyTorchMLPModel.py index a9f609e8e..22fb9c3f0 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPModel.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPModel.py @@ -36,7 +36,7 @@ class PyTorchMLPModel(nn.Module): """ def __init__(self, input_dim: int, output_dim: int, **kwargs): - super(PyTorchMLPModel, self).__init__() + super().__init__() hidden_dim: int = kwargs.get("hidden_dim", 256) dropout_percent: int = kwargs.get("dropout_percent", 0.2) n_layer: int = kwargs.get("n_layer", 1) @@ -65,7 +65,7 @@ class Block(nn.Module): """ def __init__(self, hidden_dim: int, dropout_percent: int): - super(Block, self).__init__() + super().__init__() self.ff = FeedForward(hidden_dim) self.dropout = nn.Dropout(p=dropout_percent) self.ln = nn.LayerNorm(hidden_dim) @@ -85,7 +85,7 @@ class FeedForward(nn.Module): """ def __init__(self, hidden_dim: int): - super(FeedForward, self).__init__() + super().__init__() self.net = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), diff --git a/freqtrade/freqai/prediction_models/PyTorchRegressor.py b/freqtrade/freqai/prediction_models/PyTorchRegressor.py index 837fbd836..440db96b9 100644 --- a/freqtrade/freqai/prediction_models/PyTorchRegressor.py +++ b/freqtrade/freqai/prediction_models/PyTorchRegressor.py @@ -47,4 +47,4 @@ class PyTorchRegressor(BasePyTorchModel): y = self.model.model(x) pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]]) - return (pred_df, dk.do_predict) \ No newline at end of file + return (pred_df, dk.do_predict) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 7931dc7a4..c1d9998d6 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -229,7 +229,6 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): } }) - if freqai.dd.model_type == 'joblib': model_file_extension = ".joblib" elif freqai.dd.model_type == "pytorch":