stable/freqtrade/freqai/torch/PyTorchDataConvertor.py

68 lines
2.0 KiB
Python
Raw Normal View History

2023-04-03 12:19:10 +00:00
from abc import ABC, abstractmethod
2023-04-03 16:02:07 +00:00
from typing import List, Optional
2023-04-03 12:19:10 +00:00
import pandas as pd
import torch
class PyTorchDataConvertor(ABC):
    """
    Abstract interface for turning `*_features` and `*_labels` pandas dataframes
    into lists of pytorch tensors.

    Concrete subclasses must implement both `convert_x` (features) and
    `convert_y` (labels); the class cannot be instantiated otherwise.
    """

    @abstractmethod
    def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
        """
        Convert a features dataframe into tensors.

        :param df: "*_features" dataframe.
        :param device: The device to use for training (e.g. 'cpu', 'cuda').
        :return: list of feature tensors.
        """
        ...

    @abstractmethod
    def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
        """
        Convert a labels dataframe into tensors.

        :param df: "*_labels" dataframe.
        :param device: The device to use for training (e.g. 'cpu', 'cuda').
        :return: list of target tensors.
        """
        ...
class DefaultPyTorchDataConvertor(PyTorchDataConvertor):
    """
    A default conversion that keeps features dataframe shapes.

    Features become a single float32 tensor shaped like the dataframe
    (rows, columns). Labels keep the dataframe's numpy dtype unless
    `target_tensor_type` is given, and can optionally have their trailing
    label-column axis squeezed away.
    """

    def __init__(
        self,
        target_tensor_type: Optional[torch.dtype] = None,
        squeeze_target_tensor: bool = False
    ):
        """
        :param target_tensor_type: type of target tensor, for classification use
            torch.long, for regressor use torch.float or torch.double. When None,
            the dtype inferred from the labels dataframe is kept.
        :param squeeze_target_tensor: controls the target shape, used for loss functions
            that require a 1D target. Only the trailing label-column axis is removed,
            so a (rows, 1) labels frame becomes (rows,).
        """
        self._target_tensor_type = target_tensor_type
        self._squeeze_target_tensor = squeeze_target_tensor

    def convert_x(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
        """
        :param df: "*_features" dataframe.
        :param device: The device to use for training (e.g. 'cpu', 'cuda').
            When falsy, the tensor stays on the CPU.
        :return: single-element list holding a float32 tensor shaped like `df`.
        """
        # from_numpy shares memory with the dataframe's array; .float() only copies
        # when the source dtype is not already float32.
        x = torch.from_numpy(df.to_numpy()).float()
        if device:
            x = x.to(device)
        return [x]

    def convert_y(self, df: pd.DataFrame, device: Optional[str] = None) -> List[torch.Tensor]:
        """
        :param df: "*_labels" dataframe.
        :param device: The device to use for training (e.g. 'cpu', 'cuda').
            When falsy, the tensor stays on the CPU.
        :return: single-element list holding the target tensor, shaped
            (rows, label_columns), or (rows,) when squeezing a single label column.
        """
        y = torch.from_numpy(df.to_numpy())
        if self._target_tensor_type is not None:
            y = y.to(self._target_tensor_type)
        if self._squeeze_target_tensor:
            # Squeeze only the label-column axis. A bare squeeze() would also drop the
            # row axis when the frame has a single row, turning (1, 1) into a 0-D
            # scalar (or (1, k) into (k,)) and desynchronizing targets from features.
            y = y.squeeze(-1)
        if device:
            y = y.to(device)
        return [y]