diff --git a/freqtrade/freqai/base_models/BasePyTorchModel.py b/freqtrade/freqai/base_models/BasePyTorchModel.py index 7b968c762..d017f1fec 100644 --- a/freqtrade/freqai/base_models/BasePyTorchModel.py +++ b/freqtrade/freqai/base_models/BasePyTorchModel.py @@ -8,7 +8,7 @@ from pandas import DataFrame from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.freqai_interface import IFreqaiModel -from freqtrade.freqai.torch import PyTorchDataConvertor +from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor logger = logging.getLogger(__name__) diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py index 5b7ea462e..8694453be 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPClassifier.py @@ -4,8 +4,8 @@ import torch from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier from freqtrade.freqai.data_kitchen import FreqaiDataKitchen -from freqtrade.freqai.torch import PyTorchDataConvertor -from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor +from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor, + PyTorchDataConvertor) from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer @@ -42,7 +42,10 @@ class PyTorchMLPClassifier(BasePyTorchClassifier): @property def data_convertor(self) -> PyTorchDataConvertor: - return DefaultPyTorchDataConvertor(target_tensor_type=torch.long, squeeze_target_tensor=True) + return DefaultPyTorchDataConvertor( + target_tensor_type=torch.long, + squeeze_target_tensor=True + ) def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py index 326f14994..5ca3486e1 100644 --- a/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py +++ b/freqtrade/freqai/prediction_models/PyTorchMLPRegressor.py @@ -4,8 +4,8 @@ import torch from freqtrade.freqai.base_models.BasePyTorchRegressor import BasePyTorchRegressor from freqtrade.freqai.data_kitchen import FreqaiDataKitchen -from freqtrade.freqai.torch import PyTorchDataConvertor -from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor +from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor, + PyTorchDataConvertor) from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer diff --git a/freqtrade/freqai/torch/PyTorchMLPModel.py b/freqtrade/freqai/torch/PyTorchMLPModel.py index 2deffd708..01192e115 100644 --- a/freqtrade/freqai/torch/PyTorchMLPModel.py +++ b/freqtrade/freqai/torch/PyTorchMLPModel.py @@ -1,5 +1,5 @@ import logging -from typing import Tuple, List +from typing import List import torch import torch.nn as nn diff --git a/freqtrade/freqai/torch/PyTorchModelTrainer.py b/freqtrade/freqai/torch/PyTorchModelTrainer.py index ef5c64a8a..09de6f940 100644 --- a/freqtrade/freqai/torch/PyTorchModelTrainer.py +++ b/freqtrade/freqai/torch/PyTorchModelTrainer.py @@ -12,6 +12,7 @@ from torch.utils.data import DataLoader, TensorDataset from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface + logger = logging.getLogger(__name__) diff --git a/freqtrade/freqai/torch/PyTorchTrainerInterface.py b/freqtrade/freqai/torch/PyTorchTrainerInterface.py index 2924f2ef9..6686555f9 100644 --- a/freqtrade/freqai/torch/PyTorchTrainerInterface.py +++ b/freqtrade/freqai/torch/PyTorchTrainerInterface.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Dict, List import pandas as pd import torch import torch.nn as nn -from pathlib import Path - class PyTorchTrainerInterface(ABC): @@ -51,4 +50,4 @@ class PyTorchTrainerInterface(ABC): get_init_model method. :checkpoint checkpoint: dict containing the model & optimizer state dicts, model_meta_data, etc.. - """ \ No newline at end of file + """