diff --git a/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py b/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py index aa5dbe629..a9db81e31 100644 --- a/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py +++ b/freqtrade/freqai/base_models/FreqaiMultiOutputRegressor.py @@ -36,9 +36,6 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor): y = self._validate_data(X="no_validation", y=y, multi_output=True) - # if is_classifier(self): - # check_classification_targets(y) - if y.ndim == 1: raise ValueError( "y must have at least two dimensions for " @@ -50,19 +47,12 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor): ): raise ValueError("Underlying estimator does not support sample weights.") - # fit_params_validated = _check_fit_params(X, fit_params) - if not fit_params: fit_params = [None] * y.shape[1] - # if not init_models: - # init_models = [None] * y.shape[1] - self.estimators_ = Parallel(n_jobs=self.n_jobs)( delayed(_fit_estimator)( self.estimator, X, y[:, i], sample_weight, **fit_params[i] - # init_model=init_models[i], eval_set=eval_sets[i], - # **fit_params_validated ) for i in range(y.shape[1]) ) diff --git a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py index a376b2c33..7fa4e293e 100644 --- a/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/CatboostRegressorMultiTarget.py @@ -60,6 +60,9 @@ class CatboostRegressorMultiTarget(BaseRegressionModel): {'eval_set': eval_sets[i], 'init_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=cbr) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) return model diff --git a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py index 7a9b5c36a..37c6bb186 100644 --- a/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/LightGBMRegressorMultiTarget.py @@ -56,9 +56,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel): 'init_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=lgb) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) - # model = FreqaiMultiOutputRegressor(estimator=lgb) - # model.fit(X=X, y=y, sample_weight=sample_weight, init_models=init_models, - # eval_sets=eval_sets, eval_sample_weight=eval_weights) return model diff --git a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py index 38c478c0b..920745ec9 100644 --- a/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py +++ b/freqtrade/freqai/prediction_models/XGBoostRegressorMultiTarget.py @@ -55,6 +55,9 @@ class XGBoostRegressorMultiTarget(BaseRegressionModel): 'xgb_model': init_models[i]}) model = FreqaiMultiOutputRegressor(estimator=xgb) + thread_training = self.freqai_info.get('multitarget_parallel_training', False) + if thread_training: + model.n_jobs = y.shape[1] model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params) return model