fix pytorch data convertor type hints
This commit is contained in:
		| @@ -1,5 +1,5 @@ | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import Optional, Tuple | ||||
| from typing import List, Optional | ||||
|  | ||||
| import pandas as pd | ||||
| import torch | ||||
| @@ -12,19 +12,17 @@ class PyTorchDataConvertor(ABC): | ||||
|     """ | ||||
|  | ||||
|     @abstractmethod | ||||
|     def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: | ||||
|     def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]: | ||||
|         """ | ||||
|         :param df: "*_features" dataframe. | ||||
|         :param device: The device to use for training (e.g. 'cpu', 'cuda'). | ||||
|         :returns: tuple of tensors. | ||||
|         """ | ||||
|  | ||||
|     @abstractmethod | ||||
|     def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: | ||||
|     def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]: | ||||
|         """ | ||||
|         :param df: "*_labels" dataframe. | ||||
|         :param device: The device to use for training (e.g. 'cpu', 'cuda'). | ||||
|         :returns: tuple of tensors. | ||||
|         """ | ||||
|  | ||||
|  | ||||
| @@ -47,14 +45,14 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor): | ||||
|         self._target_tensor_type = target_tensor_type | ||||
|         self._squeeze_target_tensor = squeeze_target_tensor | ||||
|  | ||||
|     def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: | ||||
|     def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]: | ||||
|         x = torch.from_numpy(df.values).float() | ||||
|         if device: | ||||
|             x = x.to(device) | ||||
|  | ||||
|         return x, | ||||
|         return [x] | ||||
|  | ||||
|     def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> Tuple[torch.Tensor, ...]: | ||||
|     def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]: | ||||
|         y = torch.from_numpy(df.values) | ||||
|  | ||||
|         if self._target_tensor_type: | ||||
| @@ -66,4 +64,4 @@ class DefaultPyTorchDataConvertor(PyTorchDataConvertor): | ||||
|         if device: | ||||
|             y = y.to(device) | ||||
|  | ||||
|         return y, | ||||
|         return [y] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user