add pytorch data convertor
This commit is contained in:
@@ -69,12 +69,11 @@ class BasePyTorchClassifier(BasePyTorchModel):
|
||||
)
|
||||
filtered_df = dk.normalize_data_from_metadata(filtered_df)
|
||||
dk.data_dictionary["prediction_features"] = filtered_df
|
||||
|
||||
self.data_cleaning_predict(dk)
|
||||
x = torch.from_numpy(dk.data_dictionary["prediction_features"].values)\
|
||||
.float()\
|
||||
.to(self.device)
|
||||
|
||||
x = self.data_convertor.convert_x(
|
||||
dk.data_dictionary["prediction_features"],
|
||||
device=self.device
|
||||
)
|
||||
logits = self.model.model(x)
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
predicted_classes = torch.argmax(probs, dim=-1)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
||||
@@ -7,15 +8,17 @@ from pandas import DataFrame
|
||||
|
||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||
from freqtrade.freqai.torch import PyTorchDataConvertor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasePyTorchModel(IFreqaiModel):
|
||||
class BasePyTorchModel(IFreqaiModel, ABC):
|
||||
"""
|
||||
Base class for PyTorch type models.
|
||||
User *must* inherit from this class and set fit() and predict().
|
||||
User *must* inherit from this class and set fit() and predict() and
|
||||
data_convertor property.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -69,3 +72,8 @@ class BasePyTorchModel(IFreqaiModel):
|
||||
f"({end_time - start_time:.2f} secs) --------------------")
|
||||
|
||||
return model
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def data_convertor(self) -> PyTorchDataConvertor:
|
||||
raise NotImplementedError("Abstract property")
|
||||
|
@@ -3,7 +3,6 @@ from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from pandas import DataFrame
|
||||
|
||||
from freqtrade.freqai.base_models.BasePyTorchModel import BasePyTorchModel
|
||||
@@ -41,9 +40,12 @@ class BasePyTorchRegressor(BasePyTorchModel):
|
||||
dk.data_dictionary["prediction_features"] = filtered_df
|
||||
|
||||
self.data_cleaning_predict(dk)
|
||||
x = torch.from_numpy(dk.data_dictionary["prediction_features"].values)\
|
||||
.float()\
|
||||
.to(self.device)
|
||||
x = self.data_convertor.convert_x(
|
||||
dk.data_dictionary["prediction_features"],
|
||||
device=self.device
|
||||
)
|
||||
logger.info(self.model.model)
|
||||
logger.info(self.model.model)
|
||||
|
||||
y = self.model.model(x)
|
||||
pred_df = DataFrame(y.detach().numpy(), columns=[dk.label_list[0]])
|
||||
|
Reference in New Issue
Block a user