create child class of PyTorchClassifier to implement the fit method, where we initialize the trainer and model objects

Yinon Polak 2023-03-19 14:38:49 +02:00
parent a49f62eecb
commit 366c148c10
4 changed files with 146 additions and 87 deletions


@@ -19,35 +19,32 @@ class PyTorchModelTrainer:
         optimizer: Optimizer,
         criterion: nn.Module,
         device: str,
-        batch_size: int,
-        max_iters: int,
-        max_n_eval_batches: int,
         init_model: Dict,
         model_meta_data: Dict[str, Any] = {},
+        **kwargs
     ):
         """
         :param model: The PyTorch model to be trained.
         :param optimizer: The optimizer to use for training.
         :param criterion: The loss function to use for training.
         :param device: The device to use for training (e.g. 'cpu', 'cuda').
-        :param batch_size: The size of the batches to use during training.
-        :param max_iters: The number of training iterations to run.
-            iteration here refers to the number of times we call
-            self.optimizer.step(). used to calculate n_epochs.
-        :param max_n_eval_batches: The maximum number batches to use for evaluation.
         :param init_model: A dictionary containing the initial model/optimizer
             state_dict and model_meta_data saved by self.save() method.
         :param model_meta_data: Additional metadata about the model (optional).
+        :param max_iters: The number of training iterations to run.
+            Iteration here refers to the number of times we call
+            self.optimizer.step(); used to calculate n_epochs.
+        :param batch_size: The size of the batches to use during training.
+        :param max_n_eval_batches: The maximum number of batches to use for evaluation.
         """
         self.model = model
         self.optimizer = optimizer
         self.criterion = criterion
         self.model_meta_data = model_meta_data
         self.device = device
-        self.max_iters = max_iters
-        self.batch_size = batch_size
-        self.max_n_eval_batches = max_n_eval_batches
+        self.max_iters: int = kwargs.get("max_iters", 100)
+        self.batch_size: int = kwargs.get("batch_size", 64)
+        self.max_n_eval_batches: Optional[int] = kwargs.get("max_n_eval_batches", None)
         if init_model:
             self.load_from_checkpoint(init_model)
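
The practical effect of this change is that batch_size, max_iters and max_n_eval_batches become optional: any key missing from **kwargs falls back to the defaults above (100, 64, None). A minimal sketch of construction after the change; the Linear stand-in model is illustrative and not part of the commit:

import torch.nn as nn
from torch.optim import AdamW

from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer

# Stand-in network, purely for illustration.
model = nn.Linear(10, 2)

trainer = PyTorchModelTrainer(
    model=model,
    optimizer=AdamW(model.parameters(), lr=3e-4),
    criterion=nn.CrossEntropyLoss(),
    device="cpu",
    init_model=None,          # falsy, so no checkpoint is loaded
    batch_size=128,           # forwarded via **kwargs
)

# Unspecified knobs resolve to their defaults.
assert trainer.max_iters == 100
assert trainer.batch_size == 128
assert trainer.max_n_eval_batches is None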


@@ -0,0 +1,81 @@
+from typing import Any, Dict
+
+import torch
+
+from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer
+from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
+from freqtrade.freqai.prediction_models.PyTorchClassifierClassifier import PyTorchClassifier
+from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel
+
+
+class MLPPyTorchClassifier(PyTorchClassifier):
+    """
+    This class implements the fit method of IFreqaiModel.
+    In the fit method we initialize the model and trainer objects.
+    The only requirement of the model is that it aligns with the
+    PyTorchClassifier predict method, which expects the model to
+    predict a tensor of type long.
+    The trainer defines the training loop.
+
+    Parameters are passed via `model_training_parameters` under the freqai
+    section in the config file, e.g.:
+    {
+        ...
+        "freqai": {
+            ...
+            "model_training_parameters": {
+                "learning_rate": 3e-4,
+                "trainer_kwargs": {
+                    "max_iters": 5000,
+                    "batch_size": 64,
+                    "max_n_eval_batches": null
+                },
+                "model_kwargs": {
+                    "hidden_dim": 512,
+                    "dropout_percent": 0.2,
+                    "n_layer": 1
+                }
+            }
+        }
+    }
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        model_training_params = self.freqai_info.get("model_training_parameters", {})
+        self.learning_rate: float = model_training_params.get("learning_rate", 3e-4)
+        self.model_kwargs: Dict[str, Any] = model_training_params.get("model_kwargs", {})
+        self.trainer_kwargs: Dict[str, Any] = model_training_params.get("trainer_kwargs", {})
+
+    def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
+        """
+        User sets up the training and test data to fit their desired model here.
+        :param data_dictionary: the dictionary constructed by DataHandler to hold
+            all the training and test data/labels.
+        :raises ValueError: If self.class_names is not defined in the parent class.
+        """
+        class_names = self.get_class_names()
+        self.convert_label_column_to_int(data_dictionary, dk, class_names)
+        n_features = data_dictionary["train_features"].shape[-1]
+        model = PyTorchMLPModel(
+            input_dim=n_features,
+            output_dim=len(class_names),
+            **self.model_kwargs
+        )
+        model.to(self.device)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
+        criterion = torch.nn.CrossEntropyLoss()
+        init_model = self.get_init_model(dk.pair)
+        trainer = PyTorchModelTrainer(
+            model=model,
+            optimizer=optimizer,
+            criterion=criterion,
+            model_meta_data={"class_names": class_names},
+            device=self.device,
+            init_model=init_model,
+            **self.trainer_kwargs,
+        )
+        trainer.fit(data_dictionary)
+        return trainer


@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple
 
 import numpy as np
 import numpy.typing as npt
@@ -10,17 +10,16 @@ from torch.nn import functional as F
 from freqtrade.exceptions import OperationalException
 from freqtrade.freqai.base_models.BasePyTorchModel import BasePyTorchModel
-from freqtrade.freqai.base_models.PyTorchModelTrainer import PyTorchModelTrainer
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
-from freqtrade.freqai.prediction_models.PyTorchMLPModel import PyTorchMLPModel
 
 logger = logging.getLogger(__name__)
 
-class PyTorchClassifierMultiTarget(BasePyTorchModel):
+class PyTorchClassifier(BasePyTorchModel):
     """
-    A PyTorch implementation of a multi-target classifier.
+    A PyTorch implementation of a classifier.
+    User must implement fit method.
     """
 
     def __init__(self, **kwargs):
         """
@@ -34,59 +33,9 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         """
         super().__init__(**kwargs)
-        model_training_params = self.freqai_info.get("model_training_parameters", {})
-        self.max_iters: int = model_training_params.get("max_iters", 100)
-        self.batch_size: int = model_training_params.get("batch_size", 64)
-        self.learning_rate: float = model_training_params.get("learning_rate", 3e-4)
-        self.max_n_eval_batches: Optional[int] = model_training_params.get(
-            "max_n_eval_batches", None
-        )
-        self.model_kwargs: Dict[str, any] = model_training_params.get("model_kwargs", {})
         self.class_name_to_index = None
         self.index_to_class_name = None
 
-    def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
-        """
-        User sets up the training and test data to fit their desired model here
-        :param data_dictionary: the dictionary constructed by DataHandler to hold
-            all the training and test data/labels.
-        :raises ValueError: If self.class_names is not defined in the parent class.
-        """
-        if not hasattr(self, "class_names"):
-            raise ValueError(
-                "Missing attribute: self.class_names "
-                "set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
-                "inside IStrategy.set_freqai_targets method."
-            )
-
-        self.init_class_names_to_index_mapping(self.class_names)
-        self.encode_classes_name(data_dictionary, dk)
-        n_features = data_dictionary["train_features"].shape[-1]
-        model = PyTorchMLPModel(
-            input_dim=n_features,
-            output_dim=len(self.class_names),
-            **self.model_kwargs
-        )
-        model.to(self.device)
-        optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
-        criterion = torch.nn.CrossEntropyLoss()
-        init_model = self.get_init_model(dk.pair)
-        trainer = PyTorchModelTrainer(
-            model=model,
-            optimizer=optimizer,
-            criterion=criterion,
-            model_meta_data={"class_names": self.class_names},
-            device=self.device,
-            batch_size=self.batch_size,
-            max_iters=self.max_iters,
-            max_n_eval_batches=self.max_n_eval_batches,
-            init_model=init_model
-        )
-        trainer.fit(data_dictionary)
-        return trainer
-
     def predict(
         self, unfiltered_df: DataFrame, dk: FreqaiDataKitchen, **kwargs
     ) -> Tuple[DataFrame, npt.NDArray[np.int_]]:
@@ -97,7 +46,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         :pred_df: dataframe containing the predictions
         :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
             data (NaNs) or felt uncertain about data (PCA and DI index)
-        :raises ValueError: if 'class_name' doesn't exist in model meta_data.
+        :raises ValueError: if 'class_names' doesn't exist in model meta_data.
         """
 
         class_names = self.model.model_meta_data.get("class_names", None)
@@ -106,6 +55,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
                 "Missing class names. "
                 "self.model.model_meta_data[\"class_names\"] is None."
             )
-        self.init_class_names_to_index_mapping(class_names)
+
+        if not self.class_name_to_index:
+            self.init_class_names_to_index_mapping(class_names)
 
         dk.find_features(unfiltered_df)
@@ -116,49 +67,77 @@
         dk.data_dictionary["prediction_features"] = filtered_df
         self.data_cleaning_predict(dk)
-        dk.data_dictionary["prediction_features"] = torch.tensor(
-            dk.data_dictionary["prediction_features"].values
-        ).float().to(self.device)
+        x = torch.from_numpy(dk.data_dictionary["prediction_features"].values)\
+            .float()\
+            .to(self.device)
 
-        logits = self.model.model(dk.data_dictionary["prediction_features"])
+        logits = self.model.model(x)
         probs = F.softmax(logits, dim=-1)
         predicted_classes = torch.argmax(probs, dim=-1)
-        predicted_classes_str = self.decode_classes_name(predicted_classes)
+        predicted_classes_str = self.decode_class_names(predicted_classes)
         pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
         pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
         pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
         return (pred_df, dk.do_predict)
 
-    def encode_classes_name(self, data_dictionary: Dict[str, pd.DataFrame], dk: FreqaiDataKitchen):
+    def encode_class_names(
+        self,
+        data_dictionary: Dict[str, pd.DataFrame],
+        dk: FreqaiDataKitchen,
+        class_names: List[str],
+    ):
         """
-        encode class name str -> int
-        assuming first column of *_labels data frame to contain class names
+        encode class name, str -> int
+        assuming first column of *_labels data frame to be the target column
+        containing the class names
         """
         target_column_name = dk.label_list[0]
         for split in ["train", "test"]:
             label_df = data_dictionary[f"{split}_labels"]
-            self.assert_valid_class_names(label_df[target_column_name])
+            self.assert_valid_class_names(label_df[target_column_name], class_names)
             label_df[target_column_name] = list(
                 map(lambda x: self.class_name_to_index[x], label_df[target_column_name])
             )
 
-    def assert_valid_class_names(self, labels: pd.Series):
-        non_defined_labels = set(labels) - set(self.class_names)
+    @staticmethod
+    def assert_valid_class_names(
+        target_column: pd.Series,
+        class_names: List[str]
+    ):
+        non_defined_labels = set(target_column) - set(class_names)
         if len(non_defined_labels) != 0:
             raise OperationalException(
                 f"Found non defined labels: {non_defined_labels}, ",
-                f"expecting labels: {self.class_names}"
+                f"expecting labels: {class_names}"
             )
 
-    def decode_classes_name(self, classes: torch.Tensor) -> List[str]:
+    def decode_class_names(self, class_ints: torch.Tensor) -> List[str]:
         """
-        decode class name int -> str
+        decode class name, int -> str
         """
-        return list(map(lambda x: self.index_to_class_name[x.item()], classes))
+        return list(map(lambda x: self.index_to_class_name[x.item()], class_ints))
 
     def init_class_names_to_index_mapping(self, class_names):
         self.class_name_to_index = {s: i for i, s in enumerate(class_names)}
         self.index_to_class_name = {i: s for i, s in enumerate(class_names)}
-        logger.info(f"class_name_to_index: {self.class_name_to_index}")
+        logger.info(f"encoded class name to index: {self.class_name_to_index}")
+
+    def convert_label_column_to_int(
+        self,
+        data_dictionary: Dict[str, pd.DataFrame],
+        dk: FreqaiDataKitchen,
+        class_names: List[str]
+    ):
+        self.init_class_names_to_index_mapping(class_names)
+        self.encode_class_names(data_dictionary, dk, class_names)
+
+    def get_class_names(self) -> List[str]:
+        if not hasattr(self, "class_names"):
+            raise ValueError(
+                "Missing attribute: self.class_names "
+                "set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
+                "inside IStrategy.set_freqai_targets method."
+            )
+        return self.class_names
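
To make the str/int mapping concrete, here is a standalone round trip of the same dict-comprehension logic used by init_class_names_to_index_mapping, encode_class_names and decode_class_names; the three class names are assumed for illustration:

import torch

class_names = ["down", "neutral", "up"]  # assumed example labels
class_name_to_index = {s: i for i, s in enumerate(class_names)}
index_to_class_name = {i: s for i, s in enumerate(class_names)}

# encode: str -> int, as done to the label column before training
encoded = [class_name_to_index[name] for name in ["up", "down", "up"]]
print(encoded)  # [2, 0, 2]

# decode: int -> str, as done to the argmax predictions at inference
predicted = torch.tensor([1, 2, 0])
print([index_to_class_name[i.item()] for i in predicted])  # ['neutral', 'up', 'down']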


@@ -88,10 +88,12 @@ def test_extract_data_and_train_model_Standard(mocker, freqai_conf, model, pca,
     if 'PyTorchClassifierMultiTarget' in model:
         model_save_ext = 'zip'
         freqai_conf['freqai']['model_training_parameters'].update({
+            "learning_rate": 3e-4,
+            "trainer_kwargs": {
                 "max_iters": 1,
                 "batch_size": 64,
-            "learning_rate": 3e-4,
                 "max_n_eval_batches": 1,
+            },
             "model_kwargs": {
                 "hidden_dim": 32,
                 "dropout_percent": 0.2,