fix imports

This commit is contained in:
Yinon Polak 2023-04-03 16:03:15 +03:00
parent bd3b70293f
commit c137666230
6 changed files with 14 additions and 11 deletions

View File

@ -8,7 +8,7 @@ from pandas import DataFrame
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.freqai_interface import IFreqaiModel
from freqtrade.freqai.torch import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
logger = logging.getLogger(__name__)

View File

@ -4,8 +4,8 @@ import torch
from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.torch import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor,
PyTorchDataConvertor)
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
@ -42,7 +42,10 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
@property
def data_convertor(self) -> PyTorchDataConvertor:
return DefaultPyTorchDataConvertor(target_tensor_type=torch.long, squeeze_target_tensor=True)
return DefaultPyTorchDataConvertor(
target_tensor_type=torch.long,
squeeze_target_tensor=True
)
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

View File

@ -4,8 +4,8 @@ import torch
from freqtrade.freqai.base_models.BasePyTorchRegressor import BasePyTorchRegressor
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.torch import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor,
PyTorchDataConvertor)
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer

View File

@ -1,5 +1,5 @@
import logging
from typing import Tuple, List
from typing import List
import torch
import torch.nn as nn

View File

@ -12,6 +12,7 @@ from torch.utils.data import DataLoader, TensorDataset
from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface
logger = logging.getLogger(__name__)

View File

@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Dict, List
import pandas as pd
import torch
import torch.nn as nn
from pathlib import Path
class PyTorchTrainerInterface(ABC):
@ -51,4 +50,4 @@ class PyTorchTrainerInterface(ABC):
get_init_model method.
:checkpoint checkpoint: dict containing the model & optimizer state dicts,
model_meta_data, etc..
"""
"""