simplified predict and predict_proba using super(). Added duplicate class label check.
This commit is contained in:
parent
6ef82dd8b6
commit
7053f81fa8
@ -4,7 +4,9 @@ from sklearn.base import is_classifier
|
|||||||
from sklearn.multioutput import MultiOutputClassifier, _fit_estimator
|
from sklearn.multioutput import MultiOutputClassifier, _fit_estimator
|
||||||
from sklearn.utils.fixes import delayed
|
from sklearn.utils.fixes import delayed
|
||||||
from sklearn.utils.multiclass import check_classification_targets
|
from sklearn.utils.multiclass import check_classification_targets
|
||||||
from sklearn.utils.validation import check_is_fitted, has_fit_parameter
|
from sklearn.utils.validation import has_fit_parameter
|
||||||
|
|
||||||
|
from freqtrade.exceptions import OperationalException
|
||||||
|
|
||||||
|
|
||||||
class FreqaiMultiOutputClassifier(MultiOutputClassifier):
|
class FreqaiMultiOutputClassifier(MultiOutputClassifier):
|
||||||
@ -65,6 +67,9 @@ class FreqaiMultiOutputClassifier(MultiOutputClassifier):
|
|||||||
self.classes_ = []
|
self.classes_ = []
|
||||||
for estimator in self.estimators_:
|
for estimator in self.estimators_:
|
||||||
self.classes_.extend(estimator.classes_)
|
self.classes_.extend(estimator.classes_)
|
||||||
|
if len(set(self.classes_)) != len(self.classes_):
|
||||||
|
raise OperationalException(f"Class labels must be unique across targets: "
|
||||||
|
f"{self.classes_}")
|
||||||
|
|
||||||
if hasattr(self.estimators_[0], "n_features_in_"):
|
if hasattr(self.estimators_[0], "n_features_in_"):
|
||||||
self.n_features_in_ = self.estimators_[0].n_features_in_
|
self.n_features_in_ = self.estimators_[0].n_features_in_
|
||||||
@ -74,56 +79,15 @@ class FreqaiMultiOutputClassifier(MultiOutputClassifier):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def predict_proba(self, X):
|
def predict_proba(self, X):
|
||||||
"""Return prediction probabilities for each class of each output.
|
|
||||||
|
|
||||||
This method will raise a ``ValueError`` if any of the
|
|
||||||
estimators do not have ``predict_proba``.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
X : array-like of shape (n_samples, n_features)
|
|
||||||
The input data.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
p : array of shape (n_samples, n_classes), or a list of n_outputs \
|
|
||||||
such arrays if n_outputs > 1.
|
|
||||||
The class probabilities of the input samples. The order of the
|
|
||||||
classes corresponds to that in the attribute :term:`classes_`.
|
|
||||||
|
|
||||||
.. versionchanged:: 0.19
|
|
||||||
This function now returns a list of arrays where the length of
|
|
||||||
the list is ``n_outputs``, and each array is (``n_samples``,
|
|
||||||
``n_classes``) for that particular output.
|
|
||||||
"""
|
"""
|
||||||
check_is_fitted(self)
|
Get predict_proba and stack arrays horizontally
|
||||||
results = np.squeeze(np.hstack(
|
"""
|
||||||
[estimator.predict_proba(X) for estimator in self.estimators_]
|
results = np.hstack(super().predict_proba(X))
|
||||||
))
|
return np.squeeze(results)
|
||||||
return results
|
|
||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
"""Predict multi-output variable using model for each target variable.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
|
||||||
The input data.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
y : {array-like, sparse matrix} of shape (n_samples, n_outputs)
|
|
||||||
Multi-output targets predicted across multiple predictors.
|
|
||||||
Note: Separate models are generated for each predictor.
|
|
||||||
"""
|
"""
|
||||||
check_is_fitted(self)
|
Get predict and squeeze into 2D array
|
||||||
if not hasattr(self.estimators_[0], "predict"):
|
"""
|
||||||
raise ValueError("The base estimator should implement a predict method")
|
results = super().predict(X)
|
||||||
|
return np.squeeze(results)
|
||||||
y = Parallel(n_jobs=self.n_jobs)(
|
|
||||||
delayed(e.predict)(X) for e in self.estimators_
|
|
||||||
)
|
|
||||||
|
|
||||||
results = np.squeeze(np.asarray(y).T)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
Loading…
Reference in New Issue
Block a user