From 1597c3aa89f2425f7ec076520a837b7582844a54 Mon Sep 17 00:00:00 2001
From: Yinon Polak
Date: Wed, 8 Mar 2023 18:36:44 +0200
Subject: [PATCH] set class names in IStrategy.set_freqai_targets method, also
 save class names with model meta data

---
 .../freqai/base_models/PyTorchModelTrainer.py |  8 +++-
 .../PyTorchClassifierMultiTarget.py           | 45 +++++++++++--------
 2 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py
index 464c5dc43..5ebecef34 100644
--- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py
+++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Dict
+from typing import Any, Dict
 
 import pandas as pd
 import torch
@@ -22,11 +22,13 @@ class PyTorchModelTrainer:
         batch_size: int,
         max_iters: int,
         eval_iters: int,
-        init_model: Dict
+        init_model: Dict,
+        model_meta_data: Dict[str, Any] = {},
     ):
         self.model = model
         self.optimizer = optimizer
         self.criterion = criterion
+        self.model_meta_data = model_meta_data
         self.device = device
         self.max_iters = max_iters
         self.batch_size = batch_size
@@ -126,6 +128,7 @@ class PyTorchModelTrainer:
         torch.save({
             'model_state_dict': self.model.state_dict(),
             'optimizer_state_dict': self.optimizer.state_dict(),
+            'model_meta_data': self.model_meta_data,
         }, path)
 
     def load_from_file(self, path: Path):
@@ -135,4 +138,5 @@ class PyTorchModelTrainer:
     def load_from_checkpoint(self, checkpoint: Dict):
         self.model.load_state_dict(checkpoint['model_state_dict'])
         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.model_meta_data = checkpoint["model_meta_data"]
         return self
diff --git a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py
index f33248e7d..13ec2d0bb 100644
--- a/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py
+++ b/freqtrade/freqai/prediction_models/PyTorchClassifierMultiTarget.py
@@ -22,25 +22,14 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
-        logger.info(f"setting multiclass_names: {self.multiclass_names}")
-        if not self.multiclass_names:
-            raise OperationalException(
-                "Missing 'multiclass_names' in freqai_info, "
-                "multi class pytorch classifier model requires predefined list of "
-                "class names matching the strategy being used."
-            )
-
-        self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
-        self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
-        logger.info(f"class_name_to_index: {self.class_name_to_index}")
-
         model_training_parameters = self.freqai_info["model_training_parameters"]
         self.n_hidden = model_training_parameters.get("n_hidden", 1024)
         self.max_iters = model_training_parameters.get("max_iters", 100)
         self.batch_size = model_training_parameters.get("batch_size", 64)
         self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
         self.eval_iters = model_training_parameters.get("eval_iters", 10)
+        self.class_name_to_index = None
+        self.index_to_class_name = None
 
     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
         """
         :param tensor_dictionary: the dictionary constructed by DataHandler to hold
         all the training and test data/labels.
""" + if not hasattr(self, "class_names"): + raise ValueError( + "Missing attribute: self.class_names " + "set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] " + "inside IStrategy.set_freqai_targets method." + ) + + self.init_class_names_to_index_mapping(self.class_names) self.encode_classes_name(data_dictionary, dk) n_features = data_dictionary['train_features'].shape[-1] model = PyTorchMLPModel( input_dim=n_features, hidden_dim=self.n_hidden, - output_dim=len(self.multiclass_names) + output_dim=len(self.class_names) ) model.to(self.device) optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate) @@ -63,6 +60,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): model=model, optimizer=optimizer, criterion=criterion, + model_meta_data={"class_names": self.class_names}, device=self.device, batch_size=self.batch_size, max_iters=self.max_iters, @@ -83,6 +81,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove data (NaNs) or felt uncertain about data (PCA and DI index) """ + class_names = self.model.model_meta_data.get("class_names", None) + if not class_names: + raise ValueError( + "Missing class names. " + "self.model.model_meta_data[\"class_names\"] is None." + ) + self.init_class_names_to_index_mapping(class_names) dk.find_features(unfiltered_df) filtered_df, _ = dk.filter_features( @@ -100,8 +105,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): probs = F.softmax(logits, dim=-1) predicted_classes = torch.argmax(probs, dim=-1) predicted_classes_str = self.decode_classes_name(predicted_classes) - - pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.multiclass_names) + pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names) pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]]) pred_df = pd.concat([pred_df, pred_df_prob], axis=1) return (pred_df, dk.do_predict) @@ -120,11 +124,11 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): ) def assert_valid_class_names(self, labels: pd.Series): - non_defined_labels = set(labels) - set(self.multiclass_names) + non_defined_labels = set(labels) - set(self.class_names) if len(non_defined_labels) != 0: raise OperationalException( f"Found non defined labels: {non_defined_labels}, ", - f"expecting labels: {self.multiclass_names}" + f"expecting labels: {self.class_names}" ) def decode_classes_name(self, classes: torch.Tensor) -> List[str]: @@ -132,3 +136,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): decode class name int -> str """ return list(map(lambda x: self.index_to_class_name[x.item()], classes)) + + def init_class_names_to_index_mapping(self, class_names): + self.class_name_to_index = {s: i for i, s in enumerate(class_names)} + self.index_to_class_name = {i: s for i, s in enumerate(class_names)} + logger.info(f"class_name_to_index: {self.class_name_to_index}")