fix pytorch data convertor type hints

This commit is contained in:
Yinon Polak 2023-04-03 19:02:07 +03:00
parent 0c4574b3b7
commit 6b204c97ed

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple
from typing import List, Optional
import pandas as pd
import torch
@ -12,19 +12,17 @@ class PyTorchDataConvertor(ABC):
"""
@abstractmethod
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
"""
:param df: "*_features" dataframe.
:param device: The device to use for training (e.g. 'cpu', 'cuda').
:returns: tuple of tensors.
"""
@abstractmethod
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
"""
:param df: "*_labels" dataframe.
:param device: The device to use for training (e.g. 'cpu', 'cuda').
:returns: tuple of tensors.
"""
@ -47,14 +45,14 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
self._target_tensor_type = target_tensor_type
self._squeeze_target_tensor = squeeze_target_tensor
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
x = torch.from_numpy(df.values).float()
if device:
x = x.to(device)
return x,
return [x]
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
y = torch.from_numpy(df.values)
if self._target_tensor_type:
@ -66,4 +64,4 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
if device:
y = y.to(device)
return y,
return [y]