diff --git a/tests/freqai/conftest.py b/tests/freqai/conftest.py index 68e7ea49a..02cfdd882 100644 --- a/tests/freqai/conftest.py +++ b/tests/freqai/conftest.py @@ -83,6 +83,22 @@ def make_rl_config(conf): return conf +def mock_pytorch_mlp_model_training_parameters(conf): + return { + "learning_rate": 3e-4, + "trainer_kwargs": { + "max_iters": 1, + "batch_size": 64, + "max_n_eval_batches": 1, + }, + "model_kwargs": { + "hidden_dim": 32, + "dropout_percent": 0.2, + "n_layer": 1, + } + } + + def get_patched_data_kitchen(mocker, freqaiconf): dk = FreqaiDataKitchen(freqaiconf) return dk diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index b4d808af2..5b460cda1 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -89,19 +89,8 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, if 'PyTorchMLPRegressor' in model: model_save_ext = 'zip' - freqai_conf['freqai']['model_training_parameters'].update({ - "learning_rate": 3e-4, - "trainer_kwargs": { - "max_iters": 1, - "batch_size": 64, - "max_n_eval_batches": 1, - }, - "model_kwargs": { - "hidden_dim": 32, - "dropout_percent": 0.2, - "n_layer": 1, - } - }) + pytorch_mlp_mtp = mock_pytorch_mlp_model_training_parameters() + freqai_conf['freqai']['model_training_parameters'].update(pytorch_mlp_mtp) strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) @@ -214,19 +203,8 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): strategy, freqai.dk, data_load_timerange) if 'PyTorchMLPClassifier': - freqai_conf['freqai']['model_training_parameters'].update({ - "learning_rate": 3e-4, - "trainer_kwargs": { - "max_iters": 1, - "batch_size": 64, - "max_n_eval_batches": 1, - }, - "model_kwargs": { - "hidden_dim": 32, - "dropout_percent": 0.2, - "n_layer": 1, - } - }) + pytorch_mlp_mtp = mock_pytorch_mlp_model_training_parameters() + freqai_conf['freqai']['model_training_parameters'].update(pytorch_mlp_mtp) if freqai.dd.model_type == 'joblib': model_file_extension = ".joblib" @@ -251,10 +229,12 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): ("LightGBMRegressor", 2, "freqai_test_strat"), ("XGBoostRegressor", 2, "freqai_test_strat"), ("CatboostRegressor", 2, "freqai_test_strat"), + ("PyTorchMLPRegressor", 2, "freqai_test_strat"), ("ReinforcementLearner", 3, "freqai_rl_test_strat"), ("XGBoostClassifier", 2, "freqai_test_classifier"), ("LightGBMClassifier", 2, "freqai_test_classifier"), - ("CatboostClassifier", 2, "freqai_test_classifier") + ("CatboostClassifier", 2, "freqai_test_classifier"), + ("PyTorchMLPClassifier", 2, "freqai_test_classifier") ], ) def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog): @@ -275,6 +255,10 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog) if 'test_4ac' in model: freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models") + if 'PyTorchMLP' in model: + pytorch_mlp_mtp = mock_pytorch_mlp_model_training_parameters() + freqai_conf['freqai']['model_training_parameters'].update(pytorch_mlp_mtp) + freqai_conf.get("freqai", {}).get("feature_parameters", {}).update( {"indicator_periods_candles": [2]})