enable continual learning and evaluation sets on multioutput models.

This commit is contained in:
robcaulk
2022-09-10 16:54:13 +02:00
parent 170bec0438
commit 10b6aebc5f
12 changed files with 170 additions and 38 deletions

View File

@@ -3,8 +3,8 @@ from typing import Any, Dict
from xgboost import XGBRegressor
from freqtrade.freqai.base_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.prediction_models.BaseRegressionModel import BaseRegressionModel
logger = logging.getLogger(__name__)
@@ -31,6 +31,7 @@ class XGBoostRegressor(BaseRegressionModel):
eval_set = None
else:
eval_set = [(data_dictionary["test_features"], data_dictionary["test_labels"])]
eval_weights = [data_dictionary['test_weights']]
sample_weight = data_dictionary["train_weights"]
@@ -38,6 +39,7 @@ class XGBoostRegressor(BaseRegressionModel):
model = XGBRegressor(**self.model_training_parameters)
model.fit(X=X, y=y, sample_weight=sample_weight, eval_set=eval_set, xgb_model=xgb_model)
model.fit(X=X, y=y, sample_weight=sample_weight, eval_set=eval_set,
sample_weight_eval_set=eval_weights, xgb_model=xgb_model)
return model