diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py index 3abc56fb1..edafb3b7a 100644 --- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py +++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py @@ -35,7 +35,6 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): super().__init__(**kwargs) model_training_params = self.freqai_info.get("model_training_parameters", {}) - self.n_hidden: int = model_training_params.get("n_hidden", 1024) self.max_iters: int = model_training_params.get("max_iters", 100) self.batch_size: int = model_training_params.get("batch_size", 64) self.learning_rate: float = model_training_params.get("learning_rate", 3e-4) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index da3c28de8..3b31012b2 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -89,13 +89,12 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, if 'PyTorchClassifierMultiTarget' in model: model_save_ext = 'zip' freqai_conf['freqai']['model_training_parameters'].update({ - "n_hidden": 1024, - "max_iters": 100, + "max_iters": 1, "batch_size": 64, "learning_rate": 3e-4, "max_n_eval_batches": None, "model_kwargs": { - "hidden_dim": 1024, + "hidden_dim": 32, "dropout_percent": 0.2, "n_layer": 1, }