stable/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py

89 lines
3.3 KiB
Python
Raw Normal View History

from typing import Any, Dict
2023-03-19 13:09:50 +00:00
import torch
2023-03-22 15:50:00 +00:00
from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
2023-04-03 13:03:15 +00:00
from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor,
PyTorchDataConvertor)
2023-03-21 14:09:54 +00:00
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
2023-03-22 15:50:00 +00:00
class PyTorchMLPClassifier(BasePyTorchClassifier):
    """
    FreqAI classifier that trains a PyTorch multi-layer perceptron.

    This class implements the fit method of IFreqaiModel: inside `fit` the
    model and trainer objects are created. The only requirement on the model
    is that it stays aligned with the PyTorchClassifier predict method, which
    expects the model to predict a tensor of type long.

    Parameters are supplied through `model_training_parameters` in the freqai
    section of the config file, e.g.:
    {
        ...
        "freqai": {
            ...
            "model_training_parameters" : {
                "learning_rate": 3e-4,
                "trainer_kwargs": {
                    "max_iters": 5000,
                    "batch_size": 64,
                    "max_n_eval_batches": null,
                },
                "model_kwargs": {
                    "hidden_dim": 512,
                    "dropout_percent": 0.2,
                    "n_layer": 1,
                },
            }
        }
    }
    """

    def __init__(self, **kwargs) -> None:
        """
        Initialise the parent class and cache the training configuration.

        :param kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        training_params = self.freqai_info.get("model_training_parameters", {})

        # Each setting is optional; a missing key falls back to the default
        # given here (3e-4 for the learning rate, empty dicts for the kwargs).
        self.learning_rate: float = training_params.get("learning_rate", 3e-4)
        self.model_kwargs: Dict[str, Any] = training_params.get("model_kwargs", {})
        self.trainer_kwargs: Dict[str, Any] = training_params.get("trainer_kwargs", {})

    @property
    def data_convertor(self) -> PyTorchDataConvertor:
        """
        Build the data convertor handed to the trainer.

        :return: a new DefaultPyTorchDataConvertor, created on every access,
            configured with `torch.long` targets and target squeezing enabled.
        """
        convertor = DefaultPyTorchDataConvertor(
            target_tensor_type=torch.long,
            squeeze_target_tensor=True
        )
        return convertor

    def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
        """
        User sets up the training and test data to fit their desired model here.

        :param data_dictionary: the dictionary constructed by DataHandler to hold
            all the training and test data/labels.
        :param dk: the data kitchen of the current pair; `dk.pair` selects the
            initial model passed on to the trainer.
        :param kwargs: accepted for interface compatibility; not used here.
        :return: the PyTorchModelTrainer after its `fit` call has completed.
        :raises ValueError: If self.class_names is not defined in the parent class.
        """
        labels = self.get_class_names()
        self.convert_label_column_to_int(data_dictionary, dk, labels)

        # Network width is driven by the data: one input per feature column
        # (last axis of the training features) and one output per class name.
        feature_count = data_dictionary["train_features"].shape[-1]
        network = PyTorchMLPModel(
            input_dim=feature_count,
            output_dim=len(labels),
            **self.model_kwargs
        )
        network.to(self.device)

        adamw = torch.optim.AdamW(network.parameters(), lr=self.learning_rate)
        loss_fn = torch.nn.CrossEntropyLoss()
        pair_init_model = self.get_init_model(dk.pair)

        fitted_trainer = PyTorchModelTrainer(
            model=network,
            optimizer=adamw,
            criterion=loss_fn,
            model_meta_data={"class_names": labels},
            device=self.device,
            init_model=pair_init_model,
            data_convertor=self.data_convertor,
            **self.trainer_kwargs,
        )
        fitted_trainer.fit(data_dictionary, self.splits)
        return fitted_trainer