fix pytorch data convertor type hints
This commit is contained in:
parent
0c4574b3b7
commit
6b204c97ed
@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@ -12,19 +12,17 @@ class PyTorchDataConvertor(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
|
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
:param df: "*_features" dataframe.
|
:param df: "*_features" dataframe.
|
||||||
:param device: The device to use for training (e.g. 'cpu', 'cuda').
|
:param device: The device to use for training (e.g. 'cpu', 'cuda').
|
||||||
:returns: tuple of tensors.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
|
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
:param df: "*_labels" dataframe.
|
:param df: "*_labels" dataframe.
|
||||||
:param device: The device to use for training (e.g. 'cpu', 'cuda').
|
:param device: The device to use for training (e.g. 'cpu', 'cuda').
|
||||||
:returns: tuple of tensors.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -47,14 +45,14 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
|||||||
self._target_tensor_type = target_tensor_type
|
self._target_tensor_type = target_tensor_type
|
||||||
self._squeeze_target_tensor = squeeze_target_tensor
|
self._squeeze_target_tensor = squeeze_target_tensor
|
||||||
|
|
||||||
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
|
def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
|
||||||
x = torch.from_numpy(df.values).float()
|
x = torch.from_numpy(df.values).float()
|
||||||
if device:
|
if device:
|
||||||
x = x.to(device)
|
x = x.to(device)
|
||||||
|
|
||||||
return x,
|
return [x]
|
||||||
|
|
||||||
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]:
|
def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
|
||||||
y = torch.from_numpy(df.values)
|
y = torch.from_numpy(df.values)
|
||||||
|
|
||||||
if self._target_tensor_type:
|
if self._target_tensor_type:
|
||||||
@ -66,4 +64,4 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
|
|||||||
if device:
|
if device:
|
||||||
y = y.to(device)
|
y = y.to(device)
|
||||||
|
|
||||||
return y,
|
return [y]
|
||||||
|
Loading…
Reference in New Issue
Block a user