bugfix skip test split when empty

This commit is contained in:
Yinon Polak
2023-03-28 14:40:23 +03:00
parent 8903ba5d89
commit 026b6a39a9
5 changed files with 28 additions and 15 deletions

View File

@@ -97,7 +97,7 @@ class BasePyTorchClassifier(BasePyTorchModel):
"""
target_column_name = dk.label_list[0]
for split in ["train", "test"]:
for split in self.splits:
label_df = data_dictionary[f"{split}_labels"]
self.assert_valid_class_names(label_df[target_column_name], class_names)
label_df[target_column_name] = list(

View File

@@ -22,6 +22,8 @@ class BasePyTorchModel(IFreqaiModel):
super().__init__(config=kwargs["config"])
self.dd.model_type = "pytorch"
self.device = "cuda" if torch.cuda.is_available() else "cpu"
test_size = self.freqai_info.get('data_split_parameters', {}).get('test_size')
self.splits = ["train", "test"] if test_size != 0 else ["train"]
def train(
self, unfiltered_df: DataFrame, pair: str, dk: FreqaiDataKitchen, **kwargs