set class names in IStrategy.set_freqai_targets method, also save class name with model meta data

This commit is contained in:
Yinon Polak 2023-03-08 18:36:44 +02:00
parent 7d26df01b8
commit 1597c3aa89
2 changed files with 33 additions and 20 deletions

View File

@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Dict
from typing import Any, Dict
import pandas as pd
import torch
@ -22,11 +22,13 @@ class PyTorchModelTrainer:
batch_size: int,
max_iters: int,
eval_iters: int,
init_model: Dict
init_model: Dict,
model_meta_data: Dict[str, Any] = {},
):
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.model_meta_data = model_meta_data
self.device = device
self.max_iters = max_iters
self.batch_size = batch_size
@ -126,6 +128,7 @@ class PyTorchModelTrainer:
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'model_meta_data': self.model_meta_data,
}, path)
def load_from_file(self, path: Path):
@ -135,4 +138,5 @@ class PyTorchModelTrainer:
def load_from_checkpoint(self, checkpoint: Dict):
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.model_meta_data = checkpoint["model_meta_data"]
return self

View File

@ -22,25 +22,14 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
logger.info(f"setting multiclass_names: {self.multiclass_names}")
if not self.multiclass_names:
raise OperationalException(
"Missing 'multiclass_names' in freqai_info, "
"multi class pytorch classifier model requires predefined list of "
"class names matching the strategy being used."
)
self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
logger.info(f"class_name_to_index: {self.class_name_to_index}")
model_training_parameters = self.freqai_info["model_training_parameters"]
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
self.max_iters = model_training_parameters.get("max_iters", 100)
self.batch_size = model_training_parameters.get("batch_size", 64)
self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
self.eval_iters = model_training_parameters.get("eval_iters", 10)
self.class_name_to_index = None
self.index_to_class_name = None
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
"""
@ -48,12 +37,20 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:param tensor_dictionary: the dictionary constructed by DataHandler to hold
all the training and test data/labels.
"""
if not hasattr(self, "class_names"):
raise ValueError(
"Missing attribute: self.class_names "
"set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
"inside IStrategy.set_freqai_targets method."
)
self.init_class_names_to_index_mapping(self.class_names)
self.encode_classes_name(data_dictionary, dk)
n_features = data_dictionary['train_features'].shape[-1]
model = PyTorchMLPModel(
input_dim=n_features,
hidden_dim=self.n_hidden,
output_dim=len(self.multiclass_names)
output_dim=len(self.class_names)
)
model.to(self.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
@ -63,6 +60,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
model=model,
optimizer=optimizer,
criterion=criterion,
model_meta_data={"class_names": self.class_names},
device=self.device,
batch_size=self.batch_size,
max_iters=self.max_iters,
@ -83,6 +81,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
data (NaNs) or felt uncertain about data (PCA and DI index)
"""
class_names = self.model.model_meta_data.get("class_names", None)
if not class_names:
raise ValueError(
"Missing class names. "
"self.model.model_meta_data[\"class_names\"] is None."
)
self.init_class_names_to_index_mapping(class_names)
dk.find_features(unfiltered_df)
filtered_df, _ = dk.filter_features(
@ -100,8 +105,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
probs = F.softmax(logits, dim=-1)
predicted_classes = torch.argmax(probs, dim=-1)
predicted_classes_str = self.decode_classes_name(predicted_classes)
pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.multiclass_names)
pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
return (pred_df, dk.do_predict)
@ -120,11 +124,11 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
)
def assert_valid_class_names(self, labels: pd.Series):
non_defined_labels = set(labels) - set(self.multiclass_names)
non_defined_labels = set(labels) - set(self.class_names)
if len(non_defined_labels) != 0:
raise OperationalException(
f"Found non defined labels: {non_defined_labels}, ",
f"expecting labels: {self.multiclass_names}"
f"expecting labels: {self.class_names}"
)
def decode_classes_name(self, classes: torch.Tensor) -> List[str]:
@ -132,3 +136,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
decode class name int -> str
"""
return list(map(lambda x: self.index_to_class_name[x.item()], classes))
def init_class_names_to_index_mapping(self, class_names):
self.class_name_to_index = {s: i for i, s in enumerate(class_names)}
self.index_to_class_name = {i: s for i, s in enumerate(class_names)}
logger.info(f"class_name_to_index: {self.class_name_to_index}")