allow user to multithread jobs (advanced users only)

This commit is contained in:
robcaulk 2022-09-10 22:16:49 +02:00
parent b3fc1cfde9
commit 5a0cfee27e
4 changed files with 9 additions and 13 deletions

View File

@ -36,9 +36,6 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor):
y = self._validate_data(X="no_validation", y=y, multi_output=True)
# if is_classifier(self):
# check_classification_targets(y)
if y.ndim == 1:
raise ValueError(
"y must have at least two dimensions for "
@ -50,19 +47,12 @@ class FreqaiMultiOutputRegressor(MultiOutputRegressor):
):
raise ValueError("Underlying estimator does not support sample weights.")
# fit_params_validated = _check_fit_params(X, fit_params)
if not fit_params:
fit_params = [None] * y.shape[1]
# if not init_models:
# init_models = [None] * y.shape[1]
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_estimator)(
self.estimator, X, y[:, i], sample_weight, **fit_params[i]
# init_model=init_models[i], eval_set=eval_sets[i],
# **fit_params_validated
)
for i in range(y.shape[1])
)

View File

@ -60,6 +60,9 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
{'eval_set': eval_sets[i], 'init_model': init_models[i]})
model = FreqaiMultiOutputRegressor(estimator=cbr)
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
if thread_training:
model.n_jobs = y.shape[1]
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
return model

View File

@ -56,9 +56,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
'init_model': init_models[i]})
model = FreqaiMultiOutputRegressor(estimator=lgb)
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
if thread_training:
model.n_jobs = y.shape[1]
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
# model = FreqaiMultiOutputRegressor(estimator=lgb)
# model.fit(X=X, y=y, sample_weight=sample_weight, init_models=init_models,
# eval_sets=eval_sets, eval_sample_weight=eval_weights)
return model

View File

@ -55,6 +55,9 @@ class XGBoostRegressorMultiTarget(BaseRegressionModel):
'xgb_model': init_models[i]})
model = FreqaiMultiOutputRegressor(estimator=xgb)
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
if thread_training:
model.n_jobs = y.shape[1]
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
return model