set class names in the IStrategy.set_freqai_targets method; also save the class names with the model metadata
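With this change the classifier no longer reads "multiclass_target_names" from freqai_info; the strategy declares the class names itself. A minimal sketch of the strategy side, assuming a hypothetical strategy class and target column (the attribute name self.freqai.class_names is taken from the error message raised in fit()):

import numpy as np

from freqtrade.strategy import IStrategy


class MyFreqaiStrategy(IStrategy):  # hypothetical strategy name
    def set_freqai_targets(self, dataframe, **kwargs):
        # The class names must match the label values produced below; their order
        # defines the order of the probability columns in the prediction dataframe.
        self.freqai.class_names = ["down", "up"]
        dataframe['&s-up_or_down'] = np.where(
            dataframe["close"].shift(-50) > dataframe["close"], "up", "down")
        return dataframe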
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Dict
+from typing import Any, Dict

 import pandas as pd
 import torch
@@ -22,11 +22,13 @@ class PyTorchModelTrainer:
             batch_size: int,
             max_iters: int,
             eval_iters: int,
-            init_model: Dict
+            init_model: Dict,
+            model_meta_data: Dict[str, Any] = {},
     ):
         self.model = model
         self.optimizer = optimizer
         self.criterion = criterion
+        self.model_meta_data = model_meta_data
         self.device = device
         self.max_iters = max_iters
         self.batch_size = batch_size
@@ -126,6 +128,7 @@ class PyTorchModelTrainer:
         torch.save({
             'model_state_dict': self.model.state_dict(),
             'optimizer_state_dict': self.optimizer.state_dict(),
+            'model_meta_data': self.model_meta_data,
         }, path)

     def load_from_file(self, path: Path):
@@ -135,4 +138,5 @@ class PyTorchModelTrainer:
     def load_from_checkpoint(self, checkpoint: Dict):
         self.model.load_state_dict(checkpoint['model_state_dict'])
         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.model_meta_data = checkpoint["model_meta_data"]
         return self

@@ -22,25 +22,14 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
-        logger.info(f"setting multiclass_names: {self.multiclass_names}")
-        if not self.multiclass_names:
-            raise OperationalException(
-                "Missing 'multiclass_names' in freqai_info, "
-                "multi class pytorch classifier model requires predefined list of "
-                "class names matching the strategy being used."
-            )
-
-        self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
-        self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
-        logger.info(f"class_name_to_index: {self.class_name_to_index}")
-
         model_training_parameters = self.freqai_info["model_training_parameters"]
         self.n_hidden = model_training_parameters.get("n_hidden", 1024)
         self.max_iters = model_training_parameters.get("max_iters", 100)
         self.batch_size = model_training_parameters.get("batch_size", 64)
         self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
         self.eval_iters = model_training_parameters.get("eval_iters", 10)
+        self.class_name_to_index = None
+        self.index_to_class_name = None

     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
         """
@@ -48,12 +37,20 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         :param tensor_dictionary: the dictionary constructed by DataHandler to hold
                                 all the training and test data/labels.
         """
+        if not hasattr(self, "class_names"):
+            raise ValueError(
+                "Missing attribute: self.class_names "
+                "set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
+                "inside IStrategy.set_freqai_targets method."
+            )
+
+        self.init_class_names_to_index_mapping(self.class_names)
         self.encode_classes_name(data_dictionary, dk)
         n_features = data_dictionary['train_features'].shape[-1]
         model = PyTorchMLPModel(
             input_dim=n_features,
             hidden_dim=self.n_hidden,
-            output_dim=len(self.multiclass_names)
+            output_dim=len(self.class_names)
         )
         model.to(self.device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
@@ -63,6 +60,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
             model=model,
             optimizer=optimizer,
             criterion=criterion,
+            model_meta_data={"class_names": self.class_names},
             device=self.device,
             batch_size=self.batch_size,
             max_iters=self.max_iters,
@@ -83,6 +81,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
         data (NaNs) or felt uncertain about data (PCA and DI index)
         """
+        class_names = self.model.model_meta_data.get("class_names", None)
+        if not class_names:
+            raise ValueError(
+                "Missing class names. "
+                "self.model.model_meta_data[\"class_names\"] is None."
+            )
+        self.init_class_names_to_index_mapping(class_names)

         dk.find_features(unfiltered_df)
         filtered_df, _ = dk.filter_features(
@@ -100,8 +105,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         probs = F.softmax(logits, dim=-1)
         predicted_classes = torch.argmax(probs, dim=-1)
         predicted_classes_str = self.decode_classes_name(predicted_classes)
-
-        pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.multiclass_names)
+        pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
         pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
         pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
         return (pred_df, dk.do_predict)
@@ -120,11 +124,11 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
             )

     def assert_valid_class_names(self, labels: pd.Series):
-        non_defined_labels = set(labels) - set(self.multiclass_names)
+        non_defined_labels = set(labels) - set(self.class_names)
         if len(non_defined_labels) != 0:
             raise OperationalException(
                 f"Found non defined labels: {non_defined_labels}, ",
-                f"expecting labels: {self.multiclass_names}"
+                f"expecting labels: {self.class_names}"
             )

     def decode_classes_name(self, classes: torch.Tensor) -> List[str]:
@@ -132,3 +136,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         decode class name int -> str
         """
         return list(map(lambda x: self.index_to_class_name[x.item()], classes))
+
+    def init_class_names_to_index_mapping(self, class_names):
+        self.class_name_to_index = {s: i for i, s in enumerate(class_names)}
+        self.index_to_class_name = {i: s for i, s in enumerate(class_names)}
+        logger.info(f"class_name_to_index: {self.class_name_to_index}")
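For reference, the checkpoint written by the trainer now carries the class names next to the model and optimizer state, and the classifier reads them back from model_meta_data at prediction time. A standalone sketch of that round trip, using a toy torch module and a hypothetical file path rather than the real PyTorchMLPModel and FreqAI data paths:

from pathlib import Path

import torch

model = torch.nn.Linear(4, 3)  # toy stand-in for PyTorchMLPModel
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
class_names = ["down", "neutral", "up"]  # hypothetical labels set by the strategy

path = Path("example_checkpoint.pt")  # hypothetical path
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'model_meta_data': {"class_names": class_names},  # same keys as the trainer's checkpoint
}, path)

# At prediction time the names come back from the checkpoint, so the mapping
# between output indices and labels is preserved across restarts.
checkpoint = torch.load(path)
restored_names = checkpoint["model_meta_data"]["class_names"]
index_to_class_name = {i: s for i, s in enumerate(restored_names)}

logits = model(torch.randn(2, 4))
predicted = torch.argmax(torch.softmax(logits, dim=-1), dim=-1)
print([index_to_class_name[i.item()] for i in predicted])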