set class names in IStrategy.set_freqai_targets method; also save class names with the model meta data
commit 1597c3aa89 (parent 7d26df01b8)
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Dict
+from typing import Any, Dict

 import pandas as pd
 import torch
@@ -22,11 +22,13 @@ class PyTorchModelTrainer:
         batch_size: int,
         max_iters: int,
         eval_iters: int,
-        init_model: Dict
+        init_model: Dict,
+        model_meta_data: Dict[str, Any] = {},
     ):
         self.model = model
         self.optimizer = optimizer
         self.criterion = criterion
+        self.model_meta_data = model_meta_data
         self.device = device
         self.max_iters = max_iters
         self.batch_size = batch_size
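A side note on the new parameter: `model_meta_data: Dict[str, Any] = {}` uses a mutable default argument, which Python evaluates once at function-definition time and shares across calls. A minimal sketch of the safer idiom, with a hypothetical stand-in class name:

```python
from typing import Any, Dict, Optional

class TrainerSketch:
    """Hypothetical stand-in for PyTorchModelTrainer's signature."""

    def __init__(self, model_meta_data: Optional[Dict[str, Any]] = None):
        # A fresh dict per instance avoids the shared-mutable-default pitfall.
        self.model_meta_data = model_meta_data if model_meta_data is not None else {}
```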
@@ -126,6 +128,7 @@ class PyTorchModelTrainer:
         torch.save({
             'model_state_dict': self.model.state_dict(),
             'optimizer_state_dict': self.optimizer.state_dict(),
+            'model_meta_data': self.model_meta_data,
         }, path)

     def load_from_file(self, path: Path):
@@ -135,4 +138,5 @@ class PyTorchModelTrainer:
     def load_from_checkpoint(self, checkpoint: Dict):
         self.model.load_state_dict(checkpoint['model_state_dict'])
         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.model_meta_data = checkpoint["model_meta_data"]
         return self
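Taken together, the trainer hunks make `model_meta_data` round-trip through the checkpoint. A self-contained sketch of that round trip, using a stand-in `nn.Linear` model and an illustrative file path:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Save: meta data rides alongside the state dicts, as in the diff.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'model_meta_data': {'class_names': ['down', 'up']},  # illustrative names
}, 'checkpoint.pt')

# Load: predict() can later recover the class names from the checkpoint.
checkpoint = torch.load('checkpoint.pt')
print(checkpoint['model_meta_data']['class_names'])  # ['down', 'up']
```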
@@ -22,25 +22,14 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.multiclass_names = self.freqai_info.get("multiclass_target_names", None)
-        logger.info(f"setting multiclass_names: {self.multiclass_names}")
-        if not self.multiclass_names:
-            raise OperationalException(
-                "Missing 'multiclass_names' in freqai_info, "
-                "multi class pytorch classifier model requires predefined list of "
-                "class names matching the strategy being used."
-            )
-
-        self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
-        self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}
-        logger.info(f"class_name_to_index: {self.class_name_to_index}")
-
         model_training_parameters = self.freqai_info["model_training_parameters"]
         self.n_hidden = model_training_parameters.get("n_hidden", 1024)
         self.max_iters = model_training_parameters.get("max_iters", 100)
         self.batch_size = model_training_parameters.get("batch_size", 64)
         self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
         self.eval_iters = model_training_parameters.get("eval_iters", 10)
+        self.class_name_to_index = None
+        self.index_to_class_name = None

     def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
         """
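For reference, an illustrative `freqai_info` fragment covering the keys the constructor reads; the values shown are the defaults taken from the diff:

```python
freqai_info = {
    "model_training_parameters": {
        "n_hidden": 1024,       # width of the MLP hidden layer
        "max_iters": 100,       # training iterations
        "batch_size": 64,
        "learning_rate": 3e-4,
        "eval_iters": 10,       # iterations per evaluation pass
    }
}
```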
@@ -48,12 +37,20 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         :param tensor_dictionary: the dictionary constructed by DataHandler to hold
                                   all the training and test data/labels.
         """
+        if not hasattr(self, "class_names"):
+            raise ValueError(
+                "Missing attribute: self.class_names "
+                "set self.freqai.class_names = [\"class a\", \"class b\", \"class c\"] "
+                "inside IStrategy.set_freqai_targets method."
+            )
+
+        self.init_class_names_to_index_mapping(self.class_names)
         self.encode_classes_name(data_dictionary, dk)
         n_features = data_dictionary['train_features'].shape[-1]
         model = PyTorchMLPModel(
             input_dim=n_features,
             hidden_dim=self.n_hidden,
-            output_dim=len(self.multiclass_names)
+            output_dim=len(self.class_names)
         )
         model.to(self.device)
         optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
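The new check in `fit()` expects the strategy to publish its label names before training starts. A hedged sketch of the strategy side; the target column and label values here are illustrative, not part of this commit:

```python
import numpy as np

def set_freqai_targets(self, dataframe, **kwargs):
    # Publish the class names fit() looks for on self.freqai.
    self.freqai.class_names = ["down", "up"]
    # Hypothetical binary target keyed to the next candle's close.
    dataframe["&s-up_or_down"] = np.where(
        dataframe["close"].shift(-1) > dataframe["close"], "up", "down"
    )
    return dataframe
```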
@@ -63,6 +60,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
             model=model,
             optimizer=optimizer,
             criterion=criterion,
+            model_meta_data={"class_names": self.class_names},
             device=self.device,
             batch_size=self.batch_size,
             max_iters=self.max_iters,
@@ -83,6 +81,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         :do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
         data (NaNs) or felt uncertain about data (PCA and DI index)
         """
+        class_names = self.model.model_meta_data.get("class_names", None)
+        if not class_names:
+            raise ValueError(
+                "Missing class names. "
+                "self.model.model_meta_data[\"class_names\"] is None."
+            )
+        self.init_class_names_to_index_mapping(class_names)

         dk.find_features(unfiltered_df)
         filtered_df, _ = dk.filter_features(
@@ -100,8 +105,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         probs = F.softmax(logits, dim=-1)
         predicted_classes = torch.argmax(probs, dim=-1)
         predicted_classes_str = self.decode_classes_name(predicted_classes)
-        pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.multiclass_names)
+        pred_df_prob = DataFrame(probs.detach().numpy(), columns=class_names)
         pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
         pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
         return (pred_df, dk.do_predict)
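The prediction path above reduces to logits → softmax → argmax → name lookup. A self-contained sketch, with illustrative class names standing in for the ones recovered from `model_meta_data`:

```python
import torch
import torch.nn.functional as F

class_names = ["down", "up"]  # stand-in for model_meta_data["class_names"]
index_to_class_name = dict(enumerate(class_names))

logits = torch.tensor([[0.2, 1.5], [2.0, -1.0]])
probs = F.softmax(logits, dim=-1)           # per-row class probabilities
predicted = torch.argmax(probs, dim=-1)     # index of the most likely class
print([index_to_class_name[i.item()] for i in predicted])  # ['up', 'down']
```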
@@ -120,11 +124,11 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
             )

     def assert_valid_class_names(self, labels: pd.Series):
-        non_defined_labels = set(labels) - set(self.multiclass_names)
+        non_defined_labels = set(labels) - set(self.class_names)
         if len(non_defined_labels) != 0:
             raise OperationalException(
                 f"Found non defined labels: {non_defined_labels}, ",
-                f"expecting labels: {self.multiclass_names}"
+                f"expecting labels: {self.class_names}"
             )

     def decode_classes_name(self, classes: torch.Tensor) -> List[str]:
@@ -132,3 +136,8 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
         decode class name int -> str
         """
         return list(map(lambda x: self.index_to_class_name[x.item()], classes))
+
+    def init_class_names_to_index_mapping(self, class_names):
+        self.class_name_to_index = {s: i for i, s in enumerate(class_names)}
+        self.index_to_class_name = {i: s for i, s in enumerate(class_names)}
+        logger.info(f"class_name_to_index: {self.class_name_to_index}")
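The new helper's two dicts are plain enumerations in opposite directions. With illustrative names:

```python
class_names = ["down", "neutral", "up"]
class_name_to_index = {s: i for i, s in enumerate(class_names)}
index_to_class_name = {i: s for i, s in enumerate(class_names)}
print(class_name_to_index)   # {'down': 0, 'neutral': 1, 'up': 2}
print(index_to_class_name)   # {0: 'down', 1: 'neutral', 2: 'up'}
```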