add pytorch data convertor

This commit is contained in:
Yinon Polak
2023-04-03 15:19:10 +03:00
parent 5a7ca35c6b
commit bd3b70293f
9 changed files with 168 additions and 40 deletions

View File

@@ -69,12 +69,11 @@ class BasePyTorchClassifier(BasePyTorchModel):
)
filtered_df = dk.normalize_data_from_metadata(filtered_df)
dk.data_dictionary["prediction_features"] = filtered_df
self.data_cleaning_predict(dk)
x = torch.from_numpy(dk.data_dictionary["prediction_features"].values)\
.float()\
.to(self.device)
x = self.data_convertor.convert_x(
dk.data_dictionary["prediction_features"],
device=self.device
)
logits = self.model.model(x)
probs = F.softmax(logits, dim=-1)
predicted_classes = torch.argmax(probs, dim=-1)

View File

@@ -1,4 +1,5 @@
import logging
from abc import ABC, abstractmethod
from time import time
from typing import Any
@@ -7,15 +8,17 @@ from pandas import DataFrame
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.torch import PyTorchDataConvertor
logger = logging.getLogger(__name__)
class BasePyTorchModel(IFreqaiModel):
class BasePyTorchModel(IFreqaiModel, ABC):
"""
Base class for PyTorch type models.
User *must* inherit from this class and set fit() and predict().
User *must* inherit from this class and set fit() and predict() and
data_convertor property.
"""
def __init__(self, **kwargs):
@@ -69,3 +72,8 @@ class BasePyTorchModel(IFreqaiModel):
f"({end_time - start_time:.2f} secs) --------------------")
return model
@property
@abstractmethod
def data_convertor(self) -> PyTorchDataConvertor:
raise NotImplementedError("Abstract property")

View File

@@ -3,7 +3,6 @@ from typing import Tuple
import numpy as np
import numpy.typing as npt
import torch
from pandas import DataFrame
from freqtrade.freqai.base_models.BasePyTorchModel import BasePyTorchModel
@@ -41,9 +40,12 @@ class BasePyTorchRegressor(BasePyTorchModel):
dk.data_dictionary["prediction_features"] = filtered_df
self.data_cleaning_predict(dk)
x = torch.from_numpy(dk.data_dictionary["prediction_features"].values)\
.float()\
.to(self.device)
x = self.data_convertor.convert_x(
dk.data_dictionary["prediction_features"],
device=self.device
)
logger.info(self.model.model)
logger.info(self.model.model)
y = self.model.model(x)
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])