fix pytorch data convertor type hints

This commit is contained in:
Yinon Polak 2023-04-03 19:02:07 +03:00
parent 0c4574b3b7
commit 6b204c97ed

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Tuple from typing import List, Optional
import pandas as pd import pandas as pd
import torch import torch
@ -12,19 +12,17 @@ class PyTorchDataConvertor(ABC):
""" """
@abstractmethod @abstractmethod
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
""" """
:param df: "*_features" dataframe. :param df: "*_features" dataframe.
:param device: The device to use for training (e.g. 'cpu', 'cuda'). :param device: The device to use for training (e.g. 'cpu', 'cuda').
:returns: tuple of tensors.
""" """
@abstractmethod @abstractmethod
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
""" """
:param df: "*_labels" dataframe. :param df: "*_labels" dataframe.
:param device: The device to use for training (e.g. 'cpu', 'cuda'). :param device: The device to use for training (e.g. 'cpu', 'cuda').
:returns: tuple of tensors.
""" """
@ -47,14 +45,14 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
self._target_tensor_type = target_tensor_type self._target_tensor_type = target_tensor_type
self._squeeze_target_tensor = squeeze_target_tensor self._squeeze_target_tensor = squeeze_target_tensor
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
x = torch.from_numpy(df.values).float() x = torch.from_numpy(df.values).float()
if device: if device:
x = x.to(device) x = x.to(device)
return x, return [x]
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
y = torch.from_numpy(df.values) y = torch.from_numpy(df.values)
if self._target_tensor_type: if self._target_tensor_type:
@ -66,4 +64,4 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
if device: if device:
y = y.to(device) y = y.to(device)
return y, return [y]