Merge pull request #7392 from freqtrade/improve_ai_tests
Improve freqai tests by utilizing parametrization
This commit is contained in:
commit
075748b21a
@ -36,9 +36,6 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor):
|
|||||||
|
|
||||||
y = self._validate_data(X="no_validation", y=y, multi_output=True)
|
y = self._validate_data(X="no_validation", y=y, multi_output=True)
|
||||||
|
|
||||||
# if is_classifier(self):
|
|
||||||
# check_classification_targets(y)
|
|
||||||
|
|
||||||
if y.ndim == 1:
|
if y.ndim == 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"y must have at least two dimensions for "
|
"y must have at least two dimensions for "
|
||||||
@ -50,19 +47,12 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor):
|
|||||||
):
|
):
|
||||||
raise ValueError("Underlying estimator does not support sample weights.")
|
raise ValueError("Underlying estimator does not support sample weights.")
|
||||||
|
|
||||||
# fit_params_validated = _check_fit_params(X, fit_params)
|
|
||||||
|
|
||||||
if not fit_params:
|
if not fit_params:
|
||||||
fit_params = [None] * y.shape[1]
|
fit_params = [None] * y.shape[1]
|
||||||
|
|
||||||
# if not init_models:
|
|
||||||
# init_models = [None] * y.shape[1]
|
|
||||||
|
|
||||||
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
|
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
|
||||||
delayed(_fit_estimator)(
|
delayed(_fit_estimator)(
|
||||||
self.estimator, X, y[:, i], sample_weight, **fit_params[i]
|
self.estimator, X, y[:, i], sample_weight, **fit_params[i]
|
||||||
# init_model=init_models[i], eval_set=eval_sets[i],
|
|
||||||
# **fit_params_validated
|
|
||||||
)
|
)
|
||||||
for i in range(y.shape[1])
|
for i in range(y.shape[1])
|
||||||
)
|
)
|
||||||
|
@ -60,6 +60,9 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
|
|||||||
{'eval_set': eval_sets[i], 'init_model': init_models[i]})
|
{'eval_set': eval_sets[i], 'init_model': init_models[i]})
|
||||||
|
|
||||||
model = FreqaiMultiOutputRegressor(estimator=cbr)
|
model = FreqaiMultiOutputRegressor(estimator=cbr)
|
||||||
|
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
|
||||||
|
if thread_training:
|
||||||
|
model.n_jobs = y.shape[1]
|
||||||
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
|
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -56,9 +56,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
|
|||||||
'init_model': init_models[i]})
|
'init_model': init_models[i]})
|
||||||
|
|
||||||
model = FreqaiMultiOutputRegressor(estimator=lgb)
|
model = FreqaiMultiOutputRegressor(estimator=lgb)
|
||||||
|
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
|
||||||
|
if thread_training:
|
||||||
|
model.n_jobs = y.shape[1]
|
||||||
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
|
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
|
||||||
|
|
||||||
# model = FreqaiMultiOutputRegressor(estimator=lgb)
|
|
||||||
# model.fit(X=X, y=y, sample_weight=sample_weight, init_models=init_models,
|
|
||||||
# eval_sets=eval_sets, eval_sample_weight=eval_weights)
|
|
||||||
return model
|
return model
|
||||||
|
@ -55,6 +55,9 @@ class XGBoostRegressorMultiTarget(BaseRegressionModel):
|
|||||||
'xgb_model': init_models[i]})
|
'xgb_model': init_models[i]})
|
||||||
|
|
||||||
model = FreqaiMultiOutputRegressor(estimator=xgb)
|
model = FreqaiMultiOutputRegressor(estimator=xgb)
|
||||||
|
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
|
||||||
|
if thread_training:
|
||||||
|
model.n_jobs = y.shape[1]
|
||||||
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
|
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -17,166 +17,17 @@ def is_arm() -> bool:
|
|||||||
return "arm" in machine or "aarch64" in machine
|
return "arm" in machine or "aarch64" in machine
|
||||||
|
|
||||||
|
|
||||||
def test_extract_data_and_train_model_LightGBM(mocker, freqai_conf):
|
@pytest.mark.parametrize('model', [
|
||||||
|
'LightGBMRegressor',
|
||||||
|
'XGBoostRegressor',
|
||||||
|
'CatboostRegressor',
|
||||||
|
])
|
||||||
|
def test_extract_data_and_train_model_Regressors(mocker, freqai_conf, model):
|
||||||
|
if is_arm() and model == 'CatboostRegressor':
|
||||||
|
pytest.skip("CatBoost is not supported on ARM")
|
||||||
|
|
||||||
|
freqai_conf.update({"freqaimodel": model})
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||||
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
|
||||||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
|
||||||
freqai = strategy.freqai
|
|
||||||
freqai.live = True
|
|
||||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
|
||||||
timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
|
||||||
|
|
||||||
freqai.dd.pair_dict = MagicMock()
|
|
||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(
|
|
||||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
|
||||||
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
|
|
||||||
|
|
||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_data_and_train_model_LightGBMMultiModel(mocker, freqai_conf):
|
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
|
||||||
freqai_conf.update({"strategy": "freqai_test_multimodel_strat"})
|
|
||||||
freqai_conf.update({"freqaimodel": "LightGBMRegressorMultiTarget"})
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
|
||||||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
|
||||||
freqai = strategy.freqai
|
|
||||||
freqai.live = True
|
|
||||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
|
||||||
timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
|
||||||
|
|
||||||
freqai.dd.pair_dict = MagicMock()
|
|
||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(
|
|
||||||
new_timerange, "ADA/BTC", strategy, freqai.dk, data_load_timerange)
|
|
||||||
|
|
||||||
assert len(freqai.dk.label_list) == 2
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").is_file()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").is_file()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").is_file()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").is_file()
|
|
||||||
assert len(freqai.dk.data['training_features_list']) == 26
|
|
||||||
|
|
||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(is_arm(), reason="no ARM for Catboost ...")
|
|
||||||
def test_extract_data_and_train_model_Catboost(mocker, freqai_conf):
|
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
|
||||||
freqai_conf.update({"freqaimodel": "CatboostRegressor"})
|
|
||||||
# freqai_conf.get('freqai', {}).update(
|
|
||||||
# {'model_training_parameters': {"n_estimators": 100, "verbose": 0}})
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
|
||||||
|
|
||||||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
|
||||||
freqai = strategy.freqai
|
|
||||||
freqai.live = True
|
|
||||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
|
||||||
timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
|
||||||
|
|
||||||
freqai.dd.pair_dict = MagicMock()
|
|
||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(new_timerange, "ADA/BTC",
|
|
||||||
strategy, freqai.dk, data_load_timerange)
|
|
||||||
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists()
|
|
||||||
|
|
||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(is_arm(), reason="no ARM for Catboost ...")
|
|
||||||
def test_extract_data_and_train_model_CatboostClassifier(mocker, freqai_conf):
|
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
|
||||||
freqai_conf.update({"freqaimodel": "CatboostClassifier"})
|
|
||||||
freqai_conf.update({"strategy": "freqai_test_classifier"})
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
|
||||||
|
|
||||||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
|
||||||
freqai = strategy.freqai
|
|
||||||
freqai.live = True
|
|
||||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
|
||||||
timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
|
||||||
|
|
||||||
freqai.dd.pair_dict = MagicMock()
|
|
||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(new_timerange, "ADA/BTC",
|
|
||||||
strategy, freqai.dk, data_load_timerange)
|
|
||||||
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists()
|
|
||||||
|
|
||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_data_and_train_model_LightGBMClassifier(mocker, freqai_conf):
|
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
|
||||||
freqai_conf.update({"freqaimodel": "LightGBMClassifier"})
|
|
||||||
freqai_conf.update({"strategy": "freqai_test_classifier"})
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
|
||||||
|
|
||||||
strategy.freqai_info = freqai_conf.get("freqai", {})
|
|
||||||
freqai = strategy.freqai
|
|
||||||
freqai.live = True
|
|
||||||
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
|
||||||
timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
|
||||||
|
|
||||||
freqai.dd.pair_dict = MagicMock()
|
|
||||||
|
|
||||||
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
|
||||||
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
|
||||||
|
|
||||||
freqai.extract_data_and_train_model(new_timerange, "ADA/BTC",
|
|
||||||
strategy, freqai.dk, data_load_timerange)
|
|
||||||
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists()
|
|
||||||
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists()
|
|
||||||
|
|
||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_data_and_train_model_XGBoostRegressor(mocker, freqai_conf):
|
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
|
||||||
freqai_conf.update({"freqaimodel": "XGBoostRegressor"})
|
|
||||||
freqai_conf.update({"strategy": "freqai_test_strat"})
|
freqai_conf.update({"strategy": "freqai_test_strat"})
|
||||||
|
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
@ -205,10 +56,18 @@ def test_extract_data_and_train_model_XGBoostRegressor(mocker, freqai_conf):
|
|||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
shutil.rmtree(Path(freqai.dk.full_path))
|
||||||
|
|
||||||
|
|
||||||
def test_extract_data_and_train_model_XGBoostRegressorMultiModel(mocker, freqai_conf):
|
@pytest.mark.parametrize('model', [
|
||||||
|
'LightGBMRegressorMultiTarget',
|
||||||
|
'XGBoostRegressorMultiTarget',
|
||||||
|
'CatboostRegressorMultiTarget',
|
||||||
|
])
|
||||||
|
def test_extract_data_and_train_model_MultiTargets(mocker, freqai_conf, model):
|
||||||
|
if is_arm() and model == 'CatboostRegressorMultiTarget':
|
||||||
|
pytest.skip("CatBoost is not supported on ARM")
|
||||||
|
|
||||||
freqai_conf.update({"timerange": "20180110-20180130"})
|
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||||
freqai_conf.update({"freqaimodel": "XGBoostRegressorMultiTarget"})
|
|
||||||
freqai_conf.update({"strategy": "freqai_test_multimodel_strat"})
|
freqai_conf.update({"strategy": "freqai_test_multimodel_strat"})
|
||||||
|
freqai_conf.update({"freqaimodel": model})
|
||||||
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
exchange = get_patched_exchange(mocker, freqai_conf)
|
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||||
strategy.dp = DataProvider(freqai_conf, exchange)
|
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||||
@ -237,6 +96,44 @@ def test_extract_data_and_train_model_XGBoostRegressorMultiModel(mocker, freqai_
|
|||||||
shutil.rmtree(Path(freqai.dk.full_path))
|
shutil.rmtree(Path(freqai.dk.full_path))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('model', [
|
||||||
|
'LightGBMClassifier',
|
||||||
|
'CatboostClassifier',
|
||||||
|
])
|
||||||
|
def test_extract_data_and_train_model_Classifiers(mocker, freqai_conf, model):
|
||||||
|
if is_arm() and model == 'CatboostClassifier':
|
||||||
|
pytest.skip("CatBoost is not supported on ARM")
|
||||||
|
|
||||||
|
freqai_conf.update({"freqaimodel": model})
|
||||||
|
freqai_conf.update({"strategy": "freqai_test_classifier"})
|
||||||
|
freqai_conf.update({"timerange": "20180110-20180130"})
|
||||||
|
strategy = get_patched_freqai_strategy(mocker, freqai_conf)
|
||||||
|
exchange = get_patched_exchange(mocker, freqai_conf)
|
||||||
|
strategy.dp = DataProvider(freqai_conf, exchange)
|
||||||
|
|
||||||
|
strategy.freqai_info = freqai_conf.get("freqai", {})
|
||||||
|
freqai = strategy.freqai
|
||||||
|
freqai.live = True
|
||||||
|
freqai.dk = FreqaiDataKitchen(freqai_conf)
|
||||||
|
timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||||
|
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
|
||||||
|
|
||||||
|
freqai.dd.pair_dict = MagicMock()
|
||||||
|
|
||||||
|
data_load_timerange = TimeRange.parse_timerange("20180110-20180130")
|
||||||
|
new_timerange = TimeRange.parse_timerange("20180120-20180130")
|
||||||
|
|
||||||
|
freqai.extract_data_and_train_model(new_timerange, "ADA/BTC",
|
||||||
|
strategy, freqai.dk, data_load_timerange)
|
||||||
|
|
||||||
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_model.joblib").exists()
|
||||||
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_metadata.json").exists()
|
||||||
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_trained_df.pkl").exists()
|
||||||
|
assert Path(freqai.dk.data_path / f"{freqai.dk.model_filename}_svm_model.joblib").exists()
|
||||||
|
|
||||||
|
shutil.rmtree(Path(freqai.dk.full_path))
|
||||||
|
|
||||||
|
|
||||||
def test_start_backtesting(mocker, freqai_conf):
|
def test_start_backtesting(mocker, freqai_conf):
|
||||||
freqai_conf.update({"timerange": "20180120-20180130"})
|
freqai_conf.update({"timerange": "20180120-20180130"})
|
||||||
freqai_conf.get("freqai", {}).update({"save_backtest_models": True})
|
freqai_conf.get("freqai", {}).update({"save_backtest_models": True})
|
||||||
|
Loading…
Reference in New Issue
Block a user