From 0510cf44910d4980eb66ecf2a5f6947607a7a8f9 Mon Sep 17 00:00:00 2001 From: Yinon Polak Date: Mon, 20 Mar 2023 18:08:38 +0200 Subject: [PATCH] add config params to tests --- tests/freqai/test_freqai_interface.py | 30 +++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/freqai/test_freqai_interface.py b/tests/freqai/test_freqai_interface.py index 3407a5a95..d35b00013 100644 --- a/tests/freqai/test_freqai_interface.py +++ b/tests/freqai/test_freqai_interface.py @@ -88,6 +88,19 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, if 'MLPPyTorchRegressor' in model: model_save_ext = 'zip' + freqai_conf['freqai']['model_training_parameters'].update({ + "learning_rate": 3e-4, + "trainer_kwargs": { + "max_iters": 1, + "batch_size": 64, + "max_n_eval_batches": 1, + }, + "model_kwargs": { + "hidden_dim": 32, + "dropout_percent": 0.2, + "n_layer": 1, + } + }) strategy = get_patched_freqai_strategy(mocker, freqai_conf) exchange = get_patched_exchange(mocker, freqai_conf) @@ -200,6 +213,23 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): freqai.extract_data_and_train_model(new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange) + + if 'MLPPyTorchClassifier': + freqai_conf['freqai']['model_training_parameters'].update({ + "learning_rate": 3e-4, + "trainer_kwargs": { + "max_iters": 1, + "batch_size": 64, + "max_n_eval_batches": 1, + }, + "model_kwargs": { + "hidden_dim": 32, + "dropout_percent": 0.2, + "n_layer": 1, + } + }) + + if freqai.dd.model_type == 'joblib': model_file_extension = ".joblib" elif freqai.dd.model_type == "pytorch":