diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index f8bee3659..da3c28de8 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -52,7 +52,8 @@ def can_run_model(model: str) -> None: ('ReinforcementLearner_multiproc', False, False, False, True, False, 0), ('ReinforcementLearner_test_3ac', False, False, False, False, False, 0), ('ReinforcementLearner_test_3ac', False, False, False, True, False, 0), - ('ReinforcementLearner_test_4ac', False, False, False, True, False, 0) + ('ReinforcementLearner_test_4ac', False, False, False, True, False, 0), + ('PyTorchClassifierMultiTarget', False, False, False, True, False, 0) ]) def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, dbscan, float32, can_short, shuffle, buffer): @@ -85,6 +86,21 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, if 'test_3ac' in model or 'test_4ac' in model: freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models") + if 'PyTorchClassifierMultiTarget' in model: + model_save_ext = 'zip' + freqai_conf['freqai']['model_training_parameters'].update({ + "n_hidden": 1024, + "max_iters": 100, + "batch_size": 64, + "learning_rate": 3e-4, + "max_n_eval_batches": None, + "model_kwargs": { + "hidden_dim": 1024, + "dropout_percent": 0.2, + "n_layer": 1, + } + }) + strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) strategy.dp = DataProvider(freqai_conf, exchange)