improve pytorch classifier documentation
This commit is contained in:
parent
81a2cbb4eb
commit
500c401b75
@ -20,6 +20,18 @@ class PyTorchClassifier(BasePyTorchModel):
|
||||
"""
|
||||
A PyTorch implementation of a classifier.
|
||||
User must implement fit method
|
||||
|
||||
Important!
|
||||
User must declare the target class names in the strategy, under
|
||||
IStrategy.set_freqai_targets method.
|
||||
```
|
||||
def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs):
|
||||
self.freqai.class_names = ["down", "up"]
|
||||
dataframe['&s-up_or_down'] = np.where(dataframe["close"].shift(-100) >
|
||||
dataframe["close"], 'up', 'down')
|
||||
|
||||
return dataframe
|
||||
```
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@ -127,7 +139,7 @@ class PyTorchClassifier(BasePyTorchModel):
|
||||
if not hasattr(self, "class_names"):
|
||||
raise ValueError(
|
||||
"Missing attribute: self.class_names "
|
||||
"set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
|
||||
"set self.freqai.class_names = ['class a', 'class b', 'class c'] "
|
||||
"inside IStrategy.set_freqai_targets method."
|
||||
)
|
||||
return self.class_names
|
||||
|
Loading…
Reference in New Issue
Block a user