set class names in IStrategy.set_freqai_targets method, also save class name with model meta data
This commit is contained in:
parent
7d26df01b8
commit
1597c3aa89
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
@ -22,11 +22,13 @@ class PyTorchModelTrainer:
|
||||
batch_size: int,
|
||||
max_iters: int,
|
||||
eval_iters: int,
|
||||
init_model: Dict
|
||||
init_model: Dict,
|
||||
model_meta_data: Dict[str, Any] = {},
|
||||
):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
self.model_meta_data = model_meta_data
|
||||
self.device = device
|
||||
self.max_iters = max_iters
|
||||
self.batch_size = batch_size
|
||||
@ -126,6 +128,7 @@ class PyTorchModelTrainer:
|
||||
torch.save({
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'model_meta_data': self.model_meta_data,
|
||||
}, path)
|
||||
|
||||
def load_from_file(self, path: Path):
|
||||
@ -135,4 +138,5 @@ class PyTorchModelTrainer:
|
||||
def load_from_checkpoint(self, checkpoint: Dict):
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
self.model_meta_data = checkpoint["model_meta_data"]
|
||||
return self
|
||||
|
@ -22,25 +22,14 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
|
||||
logger.info(f"setting multiclass_names: {self.multiclass_names}")
|
||||
if not self.multiclass_names:
|
||||
raise OperationalException(
|
||||
"Missing 'multiclass_names' in freqai_info, "
|
||||
"multi class pytorch classifier model requires predefined list of "
|
||||
"class names matching the strategy being used."
|
||||
)
|
||||
|
||||
self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
|
||||
self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
|
||||
logger.info(f"class_name_to_index: {self.class_name_to_index}")
|
||||
|
||||
model_training_parameters = self.freqai_info["model_training_parameters"]
|
||||
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
|
||||
self.max_iters = model_training_parameters.get("max_iters", 100)
|
||||
self.batch_size = model_training_parameters.get("batch_size", 64)
|
||||
self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
|
||||
self.eval_iters = model_training_parameters.get("eval_iters", 10)
|
||||
self.class_name_to_index = None
|
||||
self.index_to_class_name = None
|
||||
|
||||
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
|
||||
"""
|
||||
@ -48,12 +37,20 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
:param tensor_dictionary: the dictionary constructed by DataHandler to hold
|
||||
all the training and test data/labels.
|
||||
"""
|
||||
if not hasattr(self, "class_names"):
|
||||
raise ValueError(
|
||||
"Missing attribute: self.class_names "
|
||||
"set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
|
||||
"inside IStrategy.set_freqai_targets method."
|
||||
)
|
||||
|
||||
self.init_class_names_to_index_mapping(self.class_names)
|
||||
self.encode_classes_name(data_dictionary, dk)
|
||||
n_features = data_dictionary['train_features'].shape[-1]
|
||||
model = PyTorchMLPModel(
|
||||
input_dim=n_features,
|
||||
hidden_dim=self.n_hidden,
|
||||
output_dim=len(self.multiclass_names)
|
||||
output_dim=len(self.class_names)
|
||||
)
|
||||
model.to(self.device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
|
||||
@ -63,6 +60,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
model_meta_data={"class_names": self.class_names},
|
||||
device=self.device,
|
||||
batch_size=self.batch_size,
|
||||
max_iters=self.max_iters,
|
||||
@ -83,6 +81,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
|
||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||
"""
|
||||
class_names = self.model.model_meta_data.get("class_names", None)
|
||||
if not class_names:
|
||||
raise ValueError(
|
||||
"Missing class names. "
|
||||
"self.model.model_meta_data[\"class_names\"] is None."
|
||||
)
|
||||
self.init_class_names_to_index_mapping(class_names)
|
||||
|
||||
dk.find_features(unfiltered_df)
|
||||
filtered_df, _ = dk.filter_features(
|
||||
@ -100,8 +105,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
predicted_classes = torch.argmax(probs, dim=-1)
|
||||
predicted_classes_str = self.decode_classes_name(predicted_classes)
|
||||
|
||||
pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.multiclass_names)
|
||||
pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
|
||||
pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
|
||||
pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
|
||||
return (pred_df, dk.do_predict)
|
||||
@ -120,11 +124,11 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
)
|
||||
|
||||
def assert_valid_class_names(self, labels: pd.Series):
|
||||
non_defined_labels = set(labels) - set(self.multiclass_names)
|
||||
non_defined_labels = set(labels) - set(self.class_names)
|
||||
if len(non_defined_labels) != 0:
|
||||
raise OperationalException(
|
||||
f"Found non defined labels: {non_defined_labels}, ",
|
||||
f"expecting labels: {self.multiclass_names}"
|
||||
f"expecting labels: {self.class_names}"
|
||||
)
|
||||
|
||||
def decode_classes_name(self, classes: torch.Tensor) -> List[str]:
|
||||
@ -132,3 +136,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
decode class name int -> str
|
||||
"""
|
||||
return list(map(lambda x: self.index_to_class_name[x.item()], classes))
|
||||
|
||||
def init_class_names_to_index_mapping(self, class_names):
|
||||
self.class_name_to_index = {s: i for i, s in enumerate(class_names)}
|
||||
self.index_to_class_name = {i: s for i, s in enumerate(class_names)}
|
||||
logger.info(f"class_name_to_index: {self.class_name_to_index}")
|
||||
|
Loading…
Reference in New Issue
Block a user