add pytorch data convertor

This commit is contained in:
Yinon Polak
2023-04-03 15:19:10 +03:00
parent 5a7ca35c6b
commit bd3b70293f
9 changed files with 168 additions and 40 deletions

View File

@@ -4,6 +4,8 @@ import torch
from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.torch import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
@@ -38,6 +40,10 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
}
"""
@property
def data_convertor(self) -> PyTorchDataConvertor:
return DefaultPyTorchDataConvertor(target_tensor_type=torch.long, squeeze_target_tensor=True)
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
config = self.freqai_info.get("model_training_parameters", {})
@@ -72,8 +78,7 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
model_meta_data={"class_names": class_names},
device=self.device,
init_model=init_model,
target_tensor_type=torch.long,
squeeze_target_tensor=True,
data_convertor=self.data_convertor,
**self.trainer_kwargs,
)
trainer.fit(data_dictionary, self.splits)

View File

@@ -4,6 +4,8 @@ import torch
from freqtrade.freqai.base_models.BasePyTorchRegressor import BasePyTorchRegressor
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.torch import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
@@ -39,6 +41,10 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
}
"""
@property
def data_convertor(self) -> PyTorchDataConvertor:
return DefaultPyTorchDataConvertor(target_tensor_type=torch.float)
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
config = self.freqai_info.get("model_training_parameters", {})
@@ -69,7 +75,7 @@ class PyTorchMLPRegressor(BasePyTorchRegressor):
criterion=criterion,
device=self.device,
init_model=init_model,
target_tensor_type=torch.float,
data_convertor=self.data_convertor,
**self.trainer_kwargs,
)
trainer.fit(data_dictionary, self.splits)