Set class names in the IStrategy.set_freqai_targets method; also save the class names with the model meta data

This commit is contained in:
Yinon Polak 2023-03-08 18:36:44 +02:00
parent 7d26df01b8
commit 1597c3aa89
2 changed files with 33 additions and 20 deletions

View File

@ -1,6 +1,6 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Any, Dict
import pandas as pd import pandas as pd
import torch import torch
@ -22,11 +22,13 @@ class PyTorchModelTrainer:
batch_size: int, batch_size: int,
max_iters: int, max_iters: int,
eval_iters: int, eval_iters: int,
init_model: Dict init_model: Dict,
model_meta_data: Dict[str, Any] = {},
): ):
self.model = model self.model = model
self.optimizer = optimizer self.optimizer = optimizer
self.criterion = criterion self.criterion = criterion
self.model_meta_data = model_meta_data
self.device = device self.device = device
self.max_iters = max_iters self.max_iters = max_iters
self.batch_size = batch_size self.batch_size = batch_size
@ -126,6 +128,7 @@ class PyTorchModelTrainer:
torch.save({ torch.save({
'model_state_dict': self.model.state_dict(), 'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(),
'model_meta_data': self.model_meta_data,
}, path) }, path)
def load_from_file(self, path: Path): def load_from_file(self, path: Path):
@ -135,4 +138,5 @@ class PyTorchModelTrainer:
def load_from_checkpoint(self, checkpoint: Dict): def load_from_checkpoint(self, checkpoint: Dict):
self.model.load_state_dict(checkpoint['model_state_dict']) self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.model_meta_data = checkpoint["model_meta_data"]
return self return self

View File

@ -22,25 +22,14 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
logger.info(f"setting multiclass_names: {self.multiclass_names}")
if not self.multiclass_names:
raise OperationalException(
"Missing 'multiclass_names' in freqai_info, "
"multi class pytorch classifier model requires predefined list of "
"class names matching the strategy being used."
)
self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
logger.info(f"class_name_to_index: {self.class_name_to_index}")
model_training_parameters = self.freqai_info["model_training_parameters"] model_training_parameters = self.freqai_info["model_training_parameters"]
self.n_hidden = model_training_parameters.get("n_hidden", 1024) self.n_hidden = model_training_parameters.get("n_hidden", 1024)
self.max_iters = model_training_parameters.get("max_iters", 100) self.max_iters = model_training_parameters.get("max_iters", 100)
self.batch_size = model_training_parameters.get("batch_size", 64) self.batch_size = model_training_parameters.get("batch_size", 64)
self.learning_rate = model_training_parameters.get("learning_rate", 3e-4) self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
self.eval_iters = model_training_parameters.get("eval_iters", 10) self.eval_iters = model_training_parameters.get("eval_iters", 10)
self.class_name_to_index = None
self.index_to_class_name = None
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
@ -48,12 +37,20 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:param tensor_dictionary: the dictionary constructed by DataHandler to hold :param tensor_dictionary: the dictionary constructed by DataHandler to hold
all the training and test data/labels. all the training and test data/labels.
""" """
if not hasattr(self, "class_names"):
raise ValueError(
"Missing attribute: self.class_names "
"set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
"inside IStrategy.set_freqai_targets method."
)
self.init_class_names_to_index_mapping(self.class_names)
self.encode_classes_name(data_dictionary, dk) self.encode_classes_name(data_dictionary, dk)
n_features = data_dictionary['train_features'].shape[-1] n_features = data_dictionary['train_features'].shape[-1]
model = PyTorchMLPModel( model = PyTorchMLPModel(
input_dim=n_features, input_dim=n_features,
hidden_dim=self.n_hidden, hidden_dim=self.n_hidden,
output_dim=len(self.multiclass_names) output_dim=len(self.class_names)
) )
model.to(self.device) model.to(self.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate) optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
@ -63,6 +60,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
criterion=criterion, criterion=criterion,
model_meta_data={"class_names": self.class_names},
device=self.device, device=self.device,
batch_size=self.batch_size, batch_size=self.batch_size,
max_iters=self.max_iters, max_iters=self.max_iters,
@ -83,6 +81,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
data (NaNs) or felt uncertain about data (PCA and DI index) data (NaNs) or felt uncertain about data (PCA and DI index)
""" """
class_names = self.model.model_meta_data.get("class_names", None)
if not class_names:
raise ValueError(
"Missing class names. "
"self.model.model_meta_data[\"class_names\"] is None."
)
self.init_class_names_to_index_mapping(class_names)
dk.find_features(unfiltered_df) dk.find_features(unfiltered_df)
filtered_df, _ = dk.filter_features( filtered_df, _ = dk.filter_features(
@ -100,8 +105,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
probs = F.softmax(logits, dim=-1) probs = F.softmax(logits, dim=-1)
predicted_classes = torch.argmax(probs, dim=-1) predicted_classes = torch.argmax(probs, dim=-1)
predicted_classes_str = self.decode_classes_name(predicted_classes) predicted_classes_str = self.decode_classes_name(predicted_classes)
pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.multiclass_names)
pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]]) pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
pred_df = pd.concat([pred_df, pred_df_prob], axis=1) pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
return (pred_df, dk.do_predict) return (pred_df, dk.do_predict)
@ -120,11 +124,11 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
) )
def assert_valid_class_names(self, labels: pd.Series): def assert_valid_class_names(self, labels: pd.Series):
non_defined_labels = set(labels) - set(self.multiclass_names) non_defined_labels = set(labels) - set(self.class_names)
if len(non_defined_labels) != 0: if len(non_defined_labels) != 0:
raise OperationalException( raise OperationalException(
f"Found non defined labels: {non_defined_labels}, ", f"Found non defined labels: {non_defined_labels}, ",
f"expecting labels: {self.multiclass_names}" f"expecting labels: {self.class_names}"
) )
def decode_classes_name(self, classes: torch.Tensor) -> List[str]: def decode_classes_name(self, classes: torch.Tensor) -> List[str]:
@ -132,3 +136,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
decode class name int -> str decode class name int -> str
""" """
return list(map(lambda x: self.index_to_class_name[x.item()], classes)) return list(map(lambda x: self.index_to_class_name[x.item()], classes))
def init_class_names_to_index_mapping(self, class_names):
self.class_name_to_index = {s: i for i, s in enumerate(class_names)}
self.index_to_class_name = {i: s for i, s in enumerate(class_names)}
logger.info(f"class_name_to_index: {self.class_name_to_index}")