improve pytorch classifier documentation
This commit is contained in:
parent
81a2cbb4eb
commit
500c401b75
@ -20,6 +20,18 @@ class PyTorchClassifier(BasePyTorchModel):
|
|||||||
"""
|
"""
|
||||||
A PyTorch implementation of a classifier.
|
A PyTorch implementation of a classifier.
|
||||||
User must implement fit method
|
User must implement fit method
|
||||||
|
|
||||||
|
Important!
|
||||||
|
User must declare the target class names in the strategy, under
|
||||||
|
IStrategy.set_freqai_targets method.
|
||||||
|
```
|
||||||
|
def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs):
|
||||||
|
self.freqai.class_names = ["down", "up"]
|
||||||
|
dataframe['&s-up_or_down'] = np.where(dataframe["close"].shift(-100) >
|
||||||
|
dataframe["close"], 'up', 'down')
|
||||||
|
|
||||||
|
return dataframe
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -127,7 +139,7 @@ class PyTorchClassifier(BasePyTorchModel):
|
|||||||
if not hasattr(self, "class_names"):
|
if not hasattr(self, "class_names"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Missing attribute: self.class_names "
|
"Missing attribute: self.class_names "
|
||||||
"set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
|
"set self.freqai.class_names = ['class a', 'class b', 'class c'] "
|
||||||
"inside IStrategy.set_freqai_targets method."
|
"inside IStrategy.set_freqai_targets method."
|
||||||
)
|
)
|
||||||
return self.class_names
|
return self.class_names
|
||||||
|
Loading…
Reference in New Issue
Block a user