allow user to multithread jobs (advanced users only)

This commit is contained in:
robcaulk
2022-09-10 22:16:49 +02:00
parent b3fc1cfde9
commit 5a0cfee27e
4 changed files with 9 additions and 13 deletions

View File

@@ -60,6 +60,9 @@ class CatboostRegressorMultiTarget(BaseRegressionModel):
{'eval_set': eval_sets[i], 'init_model': init_models[i]})
model = FreqaiMultiOutputRegressor(estimator=cbr)
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
if thread_training:
model.n_jobs = y.shape[1]
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
return model

View File

@@ -56,9 +56,9 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
'init_model': init_models[i]})
model = FreqaiMultiOutputRegressor(estimator=lgb)
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
if thread_training:
model.n_jobs = y.shape[1]
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
# model = FreqaiMultiOutputRegressor(estimator=lgb)
# model.fit(X=X, y=y, sample_weight=sample_weight, init_models=init_models,
# eval_sets=eval_sets, eval_sample_weight=eval_weights)
return model

View File

@@ -55,6 +55,9 @@ class XGBoostRegressorMultiTarget(BaseRegressionModel):
'xgb_model': init_models[i]})
model = FreqaiMultiOutputRegressor(estimator=xgb)
thread_training = self.freqai_info.get('multitarget_parallel_training', False)
if thread_training:
model.n_jobs = y.shape[1]
model.fit(X=X, y=y, sample_weight=sample_weight, fit_params=fit_params)
return model