add pytorch data convertor

This commit is contained in:
Yinon Polak
2023-04-03 15:19:10 +03:00
parent 5a7ca35c6b
commit bd3b70293f
9 changed files with 168 additions and 40 deletions

View File

@@ -0,0 +1,56 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple
import pandas as pd
import torch
class PyTorchDataConvertor(ABC):
@abstractmethod
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
"""
:param df: "*_features" dataframe.
:param device: cpu/gpu.
:returns: tuple of tensors.
"""
@abstractmethod
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
"""
:param df: "*_labels" dataframe.
:param device: cpu/gpu.
:returns: tuple of tensors.
"""
class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
def __init__(
self,
target_tensor_type: Optional[torch.dtype] = None,
squeeze_target_tensor: bool = False
):
self._target_tensor_type = target_tensor_type
self._squeeze_target_tensor = squeeze_target_tensor
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
x = torch.from_numpy(df.values).float()
if device:
x = x.to(device)
return x,
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
y = torch.from_numpy(df.values)
if self._target_tensor_type:
y = y.to(self._target_tensor_type)
if self._squeeze_target_tensor:
y = y.squeeze()
if device:
y = y.to(device)
return y,