improve documentation
This commit is contained in:
parent
e88a0d5248
commit
8a9f2aedbb
@ -32,6 +32,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
dict: A dictionary mapping class names to their corresponding indices.
|
||||
dict: A dictionary mapping indices to their corresponding class names.
|
||||
"""
|
||||
|
||||
super().__init__(**kwargs)
|
||||
model_training_parameters = self.freqai_info["model_training_parameters"]
|
||||
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
|
||||
@ -50,6 +51,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
:raises ValueError: If self.class_names is not defined in the parent class.
|
||||
|
||||
"""
|
||||
|
||||
if not hasattr(self, "class_names"):
|
||||
raise ValueError(
|
||||
"Missing attribute: self.class_names "
|
||||
@ -93,7 +95,9 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
:pred_df: dataframe containing the predictions
|
||||
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
|
||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||
:raises ValueError: if 'class_name' doesn't exist in model meta_data.
|
||||
"""
|
||||
|
||||
class_names = self.model.model_meta_data.get("class_names", None)
|
||||
if not class_names:
|
||||
raise ValueError(
|
||||
@ -128,6 +132,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
encode class name str -> int
|
||||
assuming first column of *_labels data frame to contain class names
|
||||
"""
|
||||
|
||||
target_column_name = dk.label_list[0]
|
||||
for split in ["train", "test"]:
|
||||
label_df = data_dictionary[f"{split}_labels"]
|
||||
@ -148,6 +153,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
||||
"""
|
||||
decode class name int -> str
|
||||
"""
|
||||
|
||||
return list(map(lambda x: self.index_to_class_name[x.item()], classes))
|
||||
|
||||
def init_class_names_to_index_mapping(self, class_names):
|
||||
|
Loading…
Reference in New Issue
Block a user