increase test coverage for RL and FreqAI

This commit is contained in:
robcaulk
2022-09-23 18:04:43 +02:00
parent 95121550ef
commit 9c361f4422
3 changed files with 61 additions and 146 deletions

View File

@@ -4,13 +4,15 @@ from pathlib import Path
from unittest.mock import MagicMock
import pytest
from freqtrade.enums import RunMode
from freqtrade.configuration import TimeRange
from freqtrade.data.dataprovider import DataProvider
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.plugins.pairlistmanager import PairListManager
from tests.conftest import get_patched_exchange, log_has_re
from tests.freqai.conftest import get_patched_freqai_strategy
from freqtrade.persistence import Trade
from freqtrade.freqai.utils import download_all_data_for_training, get_required_data_timerange
def is_arm() -> bool:
@@ -173,29 +175,34 @@ def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
shutil.rmtree(Path(freqai.dk.full_path))
@pytest.mark.parametrize('model', [
'LightGBMRegressor',
'XGBoostRegressor',
'CatboostRegressor',
'ReinforcementLearner'
])
def test_start_backtesting(mocker, freqai_conf, model):
@pytest.mark.parametrize(
"model, num_files, strat",
[
("LightGBMRegressor", 6, "freqai_test_strat"),
("XGBoostRegressor", 6, "freqai_test_strat"),
("CatboostRegressor", 6, "freqai_test_strat"),
("ReinforcementLearner", 7, "freqai_rl_test_strat"),
("XGBoostClassifier", 6, "freqai_test_classifier"),
("LightGBMClassifier", 6, "freqai_test_classifier"),
("CatboostClassifier", 6, "freqai_test_classifier")
],
)
def test_start_backtesting(mocker, freqai_conf, model, num_files, strat):
freqai_conf.get("freqai", {}).update({"save_backtest_models": True})
if is_arm() and model == 'CatboostRegressor':
freqai_conf['runmode'] = RunMode.BACKTEST
Trade.use_db = False
if is_arm() and "Catboost" in model:
pytest.skip("CatBoost is not supported on ARM")
if is_mac():
pytest.skip("Reinforcement learning module not available on intel based Mac OS")
model_save_ext = 'joblib'
freqai_conf.update({"freqaimodel": model})
freqai_conf.update({"timerange": "20180110-20180130"})
freqai_conf.update({"strategy": "freqai_test_strat"})
freqai_conf.update({"timerange": "20180120-20180130"})
freqai_conf.update({"strategy": strat})
if 'ReinforcementLearner' in model:
model_save_ext = 'zip'
freqai_conf.update({"strategy": "freqai_rl_test_strat"})
freqai_conf["freqai"].update({"model_training_parameters": {
"learning_rate": 0.00025,
"gamma": 0.9,
@@ -217,8 +224,7 @@ def test_start_backtesting(mocker, freqai_conf, model):
if 'test_4ac' in model:
freqai_conf["freqaimodel_path"] = str(Path(__file__).parents[1] / "freqai" / "test_models")
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
exchange = get_patched_exchange(mocker, freqai_conf)
strategy.dp = DataProvider(freqai_conf, exchange)
@@ -237,7 +243,7 @@ def test_start_backtesting(mocker, freqai_conf, model):
freqai.start_backtesting(df, metadata, freqai.dk)
model_folders = [x for x in freqai.dd.full_path.iterdir() if x.is_dir()]
assert len(model_folders) == 6
assert len(model_folders) == num_files
shutil.rmtree(Path(freqai.dk.full_path))
@@ -455,3 +461,40 @@ def test_freqai_informative_pairs(mocker, freqai_conf, timeframes, corr_pairs):
pairs_b = strategy.gather_informative_pairs()
# we expect unique pairs * timeframes
assert len(pairs_b) == len(set(pairlist + corr_pairs)) * len(timeframes)
def test_start_set_train_queue(mocker, freqai_conf, caplog):
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
exchange = get_patched_exchange(mocker, freqai_conf)
pairlist = PairListManager(exchange, freqai_conf)
strategy.dp = DataProvider(freqai_conf, exchange, pairlist)
strategy.freqai_info = freqai_conf.get("freqai", {})
freqai = strategy.freqai
freqai.live = False
freqai.train_queue = freqai._set_train_queue()
assert log_has_re(
"Set fresh train queue from whitelist.",
caplog,
)
def test_get_required_data_timerange(mocker, freqai_conf):
time_range = get_required_data_timerange(freqai_conf)
assert (time_range.stopts - time_range.startts) == 177300
def test_download_all_data_for_training(mocker, freqai_conf, caplog, tmpdir):
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
exchange = get_patched_exchange(mocker, freqai_conf)
pairlist = PairListManager(exchange, freqai_conf)
strategy.dp = DataProvider(freqai_conf, exchange, pairlist)
freqai_conf['pairs'] = freqai_conf['exchange']['pair_whitelist']
freqai_conf['datadir'] = Path(tmpdir)
download_all_data_for_training(strategy.dp, freqai_conf)
assert log_has_re(
"Downloading",
caplog,
)