fix imports

This commit is contained in:
Yinon Polak 2023-04-03 16:03:15 +03:00
parent bd3b70293f
commit c137666230
6 changed files with 14 additions and 11 deletions

View File

@ -8,7 +8,7 @@ from pandas import DataFrame
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.torch import PyTorchDataConvertor from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -4,8 +4,8 @@ import torch
from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.torch import PyTorchDataConvertor from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor,
from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor PyTorchDataConvertor)
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
@ -42,7 +42,10 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
@property @property
def data_convertor(self) -> PyTorchDataConvertor: def data_convertor(self) -> PyTorchDataConvertor:
return DefaultPyTorchDataConvertor(target_tensor_type=torch.long, squeeze_target_tensor=True) return DefaultPyTorchDataConvertor(
target_tensor_type=torch.long,
squeeze_target_tensor=True
)
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@ -4,8 +4,8 @@ import torch
from freqtrade.freqai.base_models.BasePyTorchRegressor import BasePyTorchRegressor from freqtrade.freqai.base_models.BasePyTorchRegressor import BasePyTorchRegressor
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.torch import PyTorchDataConvertor from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor,
from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor PyTorchDataConvertor)
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Tuple, List from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -12,6 +12,7 @@ from torch.utils.data import DataLoader, TensorDataset
from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,12 +1,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple from pathlib import Path
from typing import Dict, List
import pandas as pd import pandas as pd
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path
class PyTorchTrainerInterface(ABC): class PyTorchTrainerInterface(ABC):
@ -51,4 +50,4 @@ class PyTorchTrainerInterface(ABC):
get_init_model method. get_init_model method.
:checkpoint checkpoint: dict containing the model & optimizer state dicts, :checkpoint checkpoint: dict containing the model & optimizer state dicts,
model_meta_data, etc.. model_meta_data, etc..
""" """