ad multiclass target names encoder to ints
This commit is contained in:
		| @@ -79,7 +79,8 @@ | ||||
|             "test_size": 0.33, | ||||
|             "random_state": 1 | ||||
|         }, | ||||
|         "model_training_parameters": {} | ||||
|         "model_training_parameters": {}, | ||||
|         "multiclass_target_names": ["down", "neither", "up"] | ||||
|     }, | ||||
|     "bot_name": "", | ||||
|     "force_entry_enable": true, | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| import logging | ||||
|  | ||||
| from typing import Any, Dict, Tuple | ||||
| from typing import Any, Dict, Tuple, List | ||||
| import numpy.typing as npt | ||||
|  | ||||
| import numpy as np | ||||
| @@ -9,6 +9,7 @@ import torch | ||||
| from pandas import DataFrame | ||||
| from torch.nn import functional as F | ||||
|  | ||||
| from freqtrade.exceptions import OperationalException | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
|  | ||||
| from freqtrade.freqai.base_models.BasePyTorchModel import BasePyTorchModel | ||||
| @@ -23,13 +24,23 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__init__(**kwargs) | ||||
|         # todo move to config | ||||
|         self.labels = ['0.0', '1.0', '2.0'] | ||||
|         self.n_hidden = 1024 | ||||
|         self.max_iters = 100 | ||||
|         self.batch_size = 64 | ||||
|         self.learning_rate = 3e-4 | ||||
|         self.eval_iters = 10 | ||||
|         self.multiclass_names = self.freqai_info["multiclass_target_names"] | ||||
|         if not self.multiclass_names: | ||||
|             raise OperationalException( | ||||
|                 "Missing 'multiclass_names' in freqai_info," | ||||
|                 " multi class pytorch model requires predefined list of" | ||||
|                 " class names matching the strategy being used" | ||||
|             ) | ||||
|  | ||||
|         self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)} | ||||
|         self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)} | ||||
|  | ||||
|         model_training_parameters = self.freqai_info["model_training_parameters"] | ||||
|         self.n_hidden = model_training_parameters.get("n_hidden", 1024) | ||||
|         self.max_iters = model_training_parameters.get("max_iters", 100) | ||||
|         self.batch_size = model_training_parameters.get("batch_size", 64) | ||||
|         self.learning_rate = model_training_parameters.get("learning_rate", 3e-4) | ||||
|         self.eval_iters = model_training_parameters.get("eval_iters", 10) | ||||
|  | ||||
|     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: | ||||
|         """ | ||||
| @@ -37,12 +48,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): | ||||
|         :param tensor_dictionary: the dictionary constructed by DataHandler to hold | ||||
|                                 all the training and test data/labels. | ||||
|         """ | ||||
|         n_features = data_dictionary['train_features'].shape[-1] | ||||
|  | ||||
|         self.encode_classes_name(data_dictionary, dk) | ||||
|         n_features = data_dictionary['train_features'].shape[-1] | ||||
|         model = PyTorchMLPModel( | ||||
|             input_dim=n_features, | ||||
|             hidden_dim=self.n_hidden, | ||||
|             output_dim=len(self.labels) | ||||
|             output_dim=len(self.multiclass_names) | ||||
|         ) | ||||
|         model.to(self.device) | ||||
|         optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate) | ||||
| @@ -87,9 +99,34 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel): | ||||
|  | ||||
|         logits = self.model.model(dk.data_dictionary["prediction_features"]) | ||||
|         probs = F.softmax(logits, dim=-1) | ||||
|         label_ints = torch.argmax(probs, dim=-1) | ||||
|         predicted_classes = torch.argmax(probs, dim=-1) | ||||
|         predicted_classes_str = self.decode_classes_name(predicted_classes) | ||||
|  | ||||
|         pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.labels) | ||||
|         pred_df = DataFrame(label_ints, columns=dk.label_list).astype(float).astype(str) | ||||
|         pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.multiclass_names) | ||||
|         pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]]) | ||||
|         pred_df = pd.concat([pred_df, pred_df_prob], axis=1) | ||||
|         return (pred_df, dk.do_predict) | ||||
|  | ||||
|     def encode_classes_name(self, data_dictionary: Dict[str, pd.DataFrame], dk: FreqaiDataKitchen): | ||||
|         """ | ||||
|         encode class name str -> int | ||||
|         assuming first column of *_labels data frame to contain class names | ||||
|         """ | ||||
|         target_column_name = dk.label_list[0] | ||||
|         for split in ["train", "test"]: | ||||
|             label_df = data_dictionary[f"{split}_labels"] | ||||
|             self.assert_valid_class_names(label_df[target_column_name]) | ||||
|             label_df[target_column_name] = list( | ||||
|                 map(lambda x: self.class_name_to_index[x], label_df[target_column_name]) | ||||
|             ) | ||||
|  | ||||
|     def assert_valid_class_names(self, labels: pd.Series): | ||||
|         non_defined_labels = set(labels) - set(self.multiclass_names) | ||||
|         if len(non_defined_labels) != 0: | ||||
|             raise OperationalException( | ||||
|                 f"Found non defined labels {non_defined_labels} ", | ||||
|                 f"expecting labels {self.multiclass_names}" | ||||
|             ) | ||||
|  | ||||
|     def decode_classes_name(self, classes: List[int]) -> List[str]: | ||||
|         return list(map(lambda x: self.index_to_class_name[x], classes)) | ||||
		Reference in New Issue
	
	Block a user