bugfix skip test split when empty
This commit is contained in:
@@ -97,7 +97,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
||||
"""
|
||||
|
||||
target_column_name = dk.label_list[0]
|
||||
for split in ["train", "test"]:
|
||||
for split in self.splits:
|
||||
label_df = data_dictionary[f"{split}_labels"]
|
||||
self.assert_valid_class_names(label_df[target_column_name], class_names)
|
||||
label_df[target_column_name] = list(
|
||||
|
@@ -22,6 +22,8 @@ class BasePyTorchModel(IFreqaiModel):
|
||||
super().__init__(config=kwargs["config"])
|
||||
self.dd.model_type = "pytorch"
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
test_size = self.freqai_info.get('data_split_parameters', {}).get('test_size')
|
||||
self.splits = ["train", "test"] if test_size != 0 else ["train"]
|
||||
|
||||
def train(
|
||||
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs
|
||||
|
Reference in New Issue
Block a user