Multi-target classifier is working, but not in parallel
This commit is contained in:
parent
283dab667d
commit
47056eded3
64
freqtrade/freqai/base_models/FreqaiMultiOutputClassifier.py
Normal file
64
freqtrade/freqai/base_models/FreqaiMultiOutputClassifier.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from joblib import Parallel
|
||||||
|
from sklearn.multioutput import MultiOutputRegressor, _fit_estimator
|
||||||
|
from sklearn.utils.fixes import delayed
|
||||||
|
from sklearn.utils.validation import has_fit_parameter
|
||||||
|
|
||||||
|
|
||||||
|
class FreqaiMultiOutputRegressor(MultiOutputRegressor):
    """Multi-output wrapper whose ``fit`` takes one fit-parameter dict per output.

    ``fit`` is overridden so that the estimator fitted on column ``i`` of ``y``
    receives its own keyword arguments (``fit_params[i]``).
    """

    def fit(self, X, y, sample_weight=None, fit_params=None):
        """Fit the model to data, separately for each output variable.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input data.
        y : {array-like, sparse matrix} of shape (n_samples, n_outputs)
            Multi-output targets. An indicator matrix turns on multilabel
            estimation.
        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If `None`, then samples are equally weighted.
            Only supported if the underlying regressor supports sample
            weights.
        fit_params : list of dict, default=None
            One dict per output column; entry ``i`` is unpacked as keyword
            arguments into the ``_fit_estimator`` call for ``y[:, i]``.
            Each dict may contain same or different values (e.g. different
            eval_sets or init_models). ``None`` or an empty list means no
            extra keyword arguments, and a ``None`` entry is treated as an
            empty dict.

        Returns
        -------
        self : object
            Returns a fitted instance.

        Raises
        ------
        ValueError
            If the base estimator has no ``fit`` method, if ``y`` is
            one-dimensional, or if ``sample_weight`` is given but the base
            estimator's ``fit`` does not accept it.
        """
        if not hasattr(self.estimator, "fit"):
            raise ValueError("The base estimator should implement a fit method")

        y = self._validate_data(X="no_validation", y=y, multi_output=True)

        if y.ndim == 1:
            raise ValueError(
                "y must have at least two dimensions for "
                "multi-output regression but has only one."
            )

        if sample_weight is not None and not has_fit_parameter(
            self.estimator, "sample_weight"
        ):
            raise ValueError("Underlying estimator does not support sample weights.")

        n_outputs = y.shape[1]
        if not fit_params:
            fit_params = [None] * n_outputs

        # ``**None`` raises TypeError, so a missing per-output entry is
        # replaced by an empty mapping before unpacking.
        self.estimators_ = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_estimator)(
                self.estimator, X, y[:, i], sample_weight, **(fit_params[i] or {})
            )
            for i in range(n_outputs)
        )

        # Mirror the fitted metadata of the first sub-estimator, when present.
        if hasattr(self.estimators_[0], "n_features_in_"):
            self.n_features_in_ = self.estimators_[0].n_features_in_
        if hasattr(self.estimators_[0], "feature_names_in_"):
            self.feature_names_in_ = self.estimators_[0].feature_names_in_

        # Return the fitted instance, as documented, so calls can be chained.
        return self
|
@ -0,0 +1,55 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from catboost import CatBoostClassifier, Pool
|
||||||
|
|
||||||
|
from freqtrade.freqai.base_models.BaseClassifierModel import BaseClassifierModel
|
||||||
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CatboostClassifier(BaseClassifierModel):
    """
    Classification model that trains a CatBoostClassifier with the
    'MultiClass' loss function. Only fit() is defined here; everything else
    comes from BaseClassifierModel.
    """

    def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
        """
        Build the CatBoost pools from the prepared data and train the classifier.
        :param data_dictionary: dictionary holding the train features, labels and
                                weights and, when a test split is configured,
                                the matching test entries.
        :param dk: data kitchen of the current pair; its data_path is used as the
                   CatBoost train_dir and its pair is handed to get_init_model().
        :return: the fitted CatBoostClassifier.
        """

        train_pool = Pool(
            data=data_dictionary["train_features"],
            label=data_dictionary["train_labels"],
            weight=data_dictionary["train_weights"],
        )

        # A test_size of 0 means no hold-out data, so no eval set is built.
        split_settings = self.freqai_info.get("data_split_parameters", {})
        eval_pool = None
        if split_settings.get("test_size", 0.1) != 0:
            eval_pool = Pool(
                data=data_dictionary["test_features"],
                label=data_dictionary["test_labels"],
                weight=data_dictionary["test_weights"],
            )

        output_dir = Path(dk.data_path)
        classifier = CatBoostClassifier(
            allow_writing_files=True,
            loss_function='MultiClass',
            train_dir=output_dir,
            **self.model_training_parameters,
        )

        # Whatever get_init_model() returns for this pair is passed on as init_model.
        classifier.fit(
            X=train_pool,
            eval_set=eval_pool,
            init_model=self.get_init_model(dk.pair),
            log_cout=sys.stdout,
            log_cerr=sys.stderr,
        )

        return classifier
|
Loading…
Reference in New Issue
Block a user