add pytorch mlp models to test_start_backtesting
This commit is contained in:
		| @@ -83,6 +83,22 @@ def make_rl_config(conf): | ||||
|     return conf | ||||
|  | ||||
|  | ||||
| def mock_pytorch_mlp_model_training_parameters(conf): | ||||
|     return { | ||||
|             "learning_rate": 3e-4, | ||||
|             "trainer_kwargs": { | ||||
|                 "max_iters": 1, | ||||
|                 "batch_size": 64, | ||||
|                 "max_n_eval_batches": 1, | ||||
|             }, | ||||
|             "model_kwargs": { | ||||
|                 "hidden_dim": 32, | ||||
|                 "dropout_percent": 0.2, | ||||
|                 "n_layer": 1, | ||||
|             } | ||||
|         } | ||||
|  | ||||
|  | ||||
| def get_patched_data_kitchen(mocker, freqaiconf): | ||||
|     dk = FreqaiDataKitchen(freqaiconf) | ||||
|     return dk | ||||
|   | ||||
| @@ -89,19 +89,8 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca, | ||||
|  | ||||
|     if 'PyTorchMLPRegressor' in model: | ||||
|         model_save_ext = 'zip' | ||||
|         freqai_conf['freqai']['model_training_parameters'].update({ | ||||
|             "learning_rate": 3e-4, | ||||
|             "trainer_kwargs": { | ||||
|                 "max_iters": 1, | ||||
|                 "batch_size": 64, | ||||
|                 "max_n_eval_batches": 1, | ||||
|             }, | ||||
|             "model_kwargs": { | ||||
|                 "hidden_dim": 32, | ||||
|                 "dropout_percent": 0.2, | ||||
|                 "n_layer": 1, | ||||
|             } | ||||
|         }) | ||||
|         pytorch_mlp_mtp = mock_pytorch_mlp_model_training_parameters() | ||||
|         freqai_conf['freqai']['model_training_parameters'].update(pytorch_mlp_mtp) | ||||
|  | ||||
|     strategy = get_patched_freqai_strategy(mocker, freqai_conf) | ||||
|     exchange = get_patched_exchange(mocker, freqai_conf) | ||||
| @@ -214,19 +203,8 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): | ||||
|                                         strategy, freqai.dk, data_load_timerange) | ||||
|  | ||||
|     if 'PyTorchMLPClassifier': | ||||
|         freqai_conf['freqai']['model_training_parameters'].update({ | ||||
|             "learning_rate": 3e-4, | ||||
|             "trainer_kwargs": { | ||||
|                 "max_iters": 1, | ||||
|                 "batch_size": 64, | ||||
|                 "max_n_eval_batches": 1, | ||||
|             }, | ||||
|             "model_kwargs": { | ||||
|                 "hidden_dim": 32, | ||||
|                 "dropout_percent": 0.2, | ||||
|                 "n_layer": 1, | ||||
|             } | ||||
|         }) | ||||
|         pytorch_mlp_mtp = mock_pytorch_mlp_model_training_parameters() | ||||
|         freqai_conf['freqai']['model_training_parameters'].update(pytorch_mlp_mtp) | ||||
|  | ||||
|     if freqai.dd.model_type == 'joblib': | ||||
|         model_file_extension = ".joblib" | ||||
| @@ -251,10 +229,12 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model): | ||||
|         ("LightGBMRegressor", 2, "freqai_test_strat"), | ||||
|         ("XGBoostRegressor", 2, "freqai_test_strat"), | ||||
|         ("CatboostRegressor", 2, "freqai_test_strat"), | ||||
|         ("PyTorchMLPRegressor", 2, "freqai_test_strat"), | ||||
|         ("ReinforcementLearner", 3, "freqai_rl_test_strat"), | ||||
|         ("XGBoostClassifier", 2, "freqai_test_classifier"), | ||||
|         ("LightGBMClassifier", 2, "freqai_test_classifier"), | ||||
|         ("CatboostClassifier", 2, "freqai_test_classifier") | ||||
|         ("CatboostClassifier", 2, "freqai_test_classifier"), | ||||
|         ("PyTorchMLPClassifier", 2, "freqai_test_classifier") | ||||
|     ], | ||||
|     ) | ||||
| def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog): | ||||
| @@ -275,6 +255,10 @@ def test_start_backtesting(mocker, freqai_conf, model, num_files, strat, caplog) | ||||
|     if 'test_4ac' in model: | ||||
|         freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models") | ||||
|  | ||||
|     if 'PyTorchMLP' in model: | ||||
|         pytorch_mlp_mtp = mock_pytorch_mlp_model_training_parameters() | ||||
|         freqai_conf['freqai']['model_training_parameters'].update(pytorch_mlp_mtp) | ||||
|  | ||||
|     freqai_conf.get("freqai", {}).get("feature_parameters", {}).update( | ||||
|         {"indicator_periods_candles": [2]}) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user