improve pytorch classifier documentation

This commit is contained in:
Yinon Polak 2023-03-20 18:39:50 +02:00
parent 81a2cbb4eb
commit 500c401b75

View File

@ -20,6 +20,18 @@ class PyTorchClassifier(BasePyTorchModel):
"""
A PyTorch implementation of a classifier.
User must implement fit method
Important!
User must declare the target class names in the strategy, under
IStrategy.set_freqai_targets method.
```
def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs):
self.freqai.class_names = ["down", "up"]
dataframe['&s-up_or_down'] = np.where(dataframe["close"].shift(-100) >
dataframe["close"], 'up', 'down')
return dataframe
```
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -127,7 +139,7 @@ class PyTorchClassifier(BasePyTorchModel):
if not hasattr(self, "class_names"):
raise ValueError(
"Missing attribute: self.class_names "
"set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
"set self.freqai.class_names = ['class a', 'class b', 'class c'] "
"inside IStrategy.set_freqai_targets method."
)
return self.class_names