improve documentation
This commit is contained in:
parent
e88a0d5248
commit
8a9f2aedbb
@ -32,6 +32,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
|||||||
dict: A dictionary mapping class names to their corresponding indices.
|
dict: A dictionary mapping class names to their corresponding indices.
|
||||||
dict: A dictionary mapping indices to their corresponding class names.
|
dict: A dictionary mapping indices to their corresponding class names.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
model_training_parameters = self.freqai_info["model_training_parameters"]
|
model_training_parameters = self.freqai_info["model_training_parameters"]
|
||||||
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
|
self.n_hidden = model_training_parameters.get("n_hidden", 1024)
|
||||||
@ -50,6 +51,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
|||||||
:raises ValueError: If self.class_names is not defined in the parent class.
|
:raises ValueError: If self.class_names is not defined in the parent class.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not hasattr(self, "class_names"):
|
if not hasattr(self, "class_names"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Missing attribute: self.class_names "
|
"Missing attribute: self.class_names "
|
||||||
@ -93,7 +95,9 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
|||||||
:pred_df: dataframe containing the predictions
|
:pred_df: dataframe containing the predictions
|
||||||
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
|
:do_predict: np.array of 1s and 0s to indicate places where freqai needed to remove
|
||||||
data (NaNs) or felt uncertain about data (PCA and DI index)
|
data (NaNs) or felt uncertain about data (PCA and DI index)
|
||||||
|
:raises ValueError: if 'class_name' doesn't exist in model meta_data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class_names = self.model.model_meta_data.get("class_names", None)
|
class_names = self.model.model_meta_data.get("class_names", None)
|
||||||
if not class_names:
|
if not class_names:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -128,6 +132,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
|||||||
encode class name str -> int
|
encode class name str -> int
|
||||||
assuming first column of *_labels data frame to contain class names
|
assuming first column of *_labels data frame to contain class names
|
||||||
"""
|
"""
|
||||||
|
|
||||||
target_column_name = dk.label_list[0]
|
target_column_name = dk.label_list[0]
|
||||||
for split in ["train", "test"]:
|
for split in ["train", "test"]:
|
||||||
label_df = data_dictionary[f"{split}_labels"]
|
label_df = data_dictionary[f"{split}_labels"]
|
||||||
@ -148,6 +153,7 @@ class PyTorchClassifierMultiTarget(BasePyTorchModel):
|
|||||||
"""
|
"""
|
||||||
decode class name int -> str
|
decode class name int -> str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return list(map(lambda x: self.index_to_class_name[x.item()], classes))
|
return list(map(lambda x: self.index_to_class_name[x.item()], classes))
|
||||||
|
|
||||||
def init_class_names_to_index_mapping(self, class_names):
|
def init_class_names_to_index_mapping(self, class_names):
|
||||||
|
Loading…
Reference in New Issue
Block a user