add multiclass target names encoder to ints

This commit is contained in:
Yinon Polak 2023-03-08 14:29:38 +02:00
parent 4241bff32a
commit 76fbec0c17
2 changed files with 52 additions and 14 deletions

View File

@ -79,7 +79,8 @@
"test_size": 0.33, "test_size": 0.33,
"random_state": 1 "random_state": 1
}, },
"model_training_parameters": {} "model_training_parameters": {},
"multiclass_target_names": ["down", "neither", "up"]
}, },
"bot_name": "", "bot_name": "",
"force_entry_enable": true, "force_entry_enable": true,

View File

@ -1,6 +1,6 @@
import logging import logging
from typing import Any, Dict, Tuple from typing import Any, Dict, Tuple, List
import numpy.typing as npt import numpy.typing as npt
import numpy as np import numpy as np
@ -9,6 +9,7 @@ import torch
from pandas import DataFrame from pandas import DataFrame
from torch.nn import functional as F from torch.nn import functional as F
from freqtrade.exceptions import OperationalException
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.base_models.BasePyTorchModel import BasePyTorchModel from freqtrade.freqai.base_models.BasePyTorchModel import BasePyTorchModel
@ -23,13 +24,23 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
def __init__(self, **kwargs):
    """
    Read the class-name list and training hyperparameters from freqai_info.

    Builds the class-name <-> index lookup tables from the configured
    "multiclass_target_names" list and reads optional hyperparameters from
    "model_training_parameters", falling back to defaults.

    :param kwargs: forwarded unchanged to the parent constructor.
    :raises OperationalException: if "multiclass_target_names" is missing
        from freqai_info or is empty.
    """
    super().__init__(**kwargs)
    # NOTE(review): self.freqai_info is presumably populated by the parent
    # constructor as a dict-like config section - confirm.
    # .get() so that an absent key reaches the explanatory error below
    # instead of surfacing as a bare KeyError.
    self.multiclass_names = self.freqai_info.get("multiclass_target_names")
    if not self.multiclass_names:
        raise OperationalException(
            "Missing 'multiclass_target_names' in freqai_info,"
            " multi class pytorch model requires predefined list of"
            " class names matching the strategy being used"
        )

    # Position in the configured list is the integer id of each class.
    self.class_name_to_index = {s: i for i, s in enumerate(self.multiclass_names)}
    self.index_to_class_name = {i: s for i, s in enumerate(self.multiclass_names)}

    model_training_parameters = self.freqai_info["model_training_parameters"]
    self.n_hidden = model_training_parameters.get("n_hidden", 1024)
    self.max_iters = model_training_parameters.get("max_iters", 100)
    self.batch_size = model_training_parameters.get("batch_size", 64)
    self.learning_rate = model_training_parameters.get("learning_rate", 3e-4)
    self.eval_iters = model_training_parameters.get("eval_iters", 10)
def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any: def fit(self, data_dictionary: Dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
""" """
@ -37,12 +48,13 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
:param tensor_dictionary: the dictionary constructed by DataHandler to hold :param tensor_dictionary: the dictionary constructed by DataHandler to hold
all the training and test data/labels. all the training and test data/labels.
""" """
n_features = data_dictionary['train_features'].shape[-1]
self.encode_classes_name(data_dictionary, dk)
n_features = data_dictionary['train_features'].shape[-1]
model = PyTorchMLPModel( model = PyTorchMLPModel(
input_dim=n_features, input_dim=n_features,
hidden_dim=self.n_hidden, hidden_dim=self.n_hidden,
output_dim=len(self.labels) output_dim=len(self.multiclass_names)
) )
model.to(self.device) model.to(self.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate) optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
@ -87,9 +99,34 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
logits = self.model.model(dk.data_dictionary["prediction_features"]) logits = self.model.model(dk.data_dictionary["prediction_features"])
probs = F.softmax(logits, dim=-1) probs = F.softmax(logits, dim=-1)
label_ints = torch.argmax(probs, dim=-1) predicted_classes = torch.argmax(probs, dim=-1)
predicted_classes_str = self.decode_classes_name(predicted_classes)
pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.labels) pred_df_prob = DataFrame(probs.detach().numpy(), columns=self.multiclass_names)
pred_df = DataFrame(label_ints, columns=dk.label_list).astype(float).astype(str) pred_df = DataFrame(predicted_classes_str, columns=[dk.label_list[0]])
pred_df = pd.concat([pred_df, pred_df_prob], axis=1) pred_df = pd.concat([pred_df, pred_df_prob], axis=1)
return (pred_df, dk.do_predict) return (pred_df, dk.do_predict)
def encode_classes_name(self, data_dictionary: Dict[str, pd.DataFrame], dk: FreqaiDataKitchen):
    """
    encode class name str -> int, in place

    For the "train_labels" and "test_labels" frames of data_dictionary, the
    column named by dk.label_list[0] is validated and then overwritten with
    the integer index of each class name.

    :param data_dictionary: mapping holding the "train_labels" and
        "test_labels" data frames.
    :param dk: object whose label_list[0] names the target column.
    """
    label_column = dk.label_list[0]
    name_to_index = self.class_name_to_index
    for split in ("train", "test"):
        frame = data_dictionary[f"{split}_labels"]
        names = frame[label_column]
        # Validate first so an unknown name is reported by the dedicated
        # check rather than by the dictionary lookup below.
        self.assert_valid_class_names(names)
        frame[label_column] = [name_to_index[name] for name in names]
def assert_valid_class_names(self, labels: pd.Series):
    """
    Check that every label value is one of the configured class names.

    :param labels: iterable of class-name labels to validate.
    :raises OperationalException: if any label is not present in
        self.multiclass_names.
    """
    non_defined_labels = set(labels) - set(self.multiclass_names)
    if non_defined_labels:
        # The two f-strings are adjacent literals (no comma) so they form ONE
        # message; a comma here would hand the exception two separate args
        # and the error would render as a tuple.
        raise OperationalException(
            f"Found non defined labels {non_defined_labels}, "
            f"expecting labels {self.multiclass_names}"
        )
def decode_classes_name(self, classes: List[int]) -> List[str]:
    """
    decode class index int -> class name str

    :param classes: iterable of class indices; elements may be plain ints or
        any scalar accepted by int(), such as the 0-dim tensors produced when
        iterating a 1-D torch tensor.
    :return: the class names, in the same order as the input.
    :raises KeyError: if an index has no entry in self.index_to_class_name.
    """
    # Normalise each element with int(): a 0-dim torch tensor hashes by
    # identity, so using it directly as a key never matches the int keys.
    return [self.index_to_class_name[int(index)] for index in classes]