fix imports
This commit is contained in:
parent
bd3b70293f
commit
c137666230
@ -8,7 +8,7 @@ from pandas import DataFrame
|
|||||||
|
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
from freqtrade.freqai.freqai_interface import IFreqaiModel
|
||||||
from freqtrade.freqai.torch import PyTorchDataConvertor
|
from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -4,8 +4,8 @@ import torch
|
|||||||
|
|
||||||
from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier
|
from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.torch import PyTorchDataConvertor
|
from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor,
|
||||||
from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor
|
PyTorchDataConvertor)
|
||||||
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
|
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
|
||||||
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
|
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
|
||||||
|
|
||||||
@ -42,7 +42,10 @@ class PyTorchMLPClassifier(BasePyTorchClassifier):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def data_convertor(self) -> PyTorchDataConvertor:
|
def data_convertor(self) -> PyTorchDataConvertor:
|
||||||
return DefaultPyTorchDataConvertor(target_tensor_type=torch.long, squeeze_target_tensor=True)
|
return DefaultPyTorchDataConvertor(
|
||||||
|
target_tensor_type=torch.long,
|
||||||
|
squeeze_target_tensor=True
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
@ -4,8 +4,8 @@ import torch
|
|||||||
|
|
||||||
from freqtrade.freqai.base_models.BasePyTorchRegressor import BasePyTorchRegressor
|
from freqtrade.freqai.base_models.BasePyTorchRegressor import BasePyTorchRegressor
|
||||||
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
|
||||||
from freqtrade.freqai.torch import PyTorchDataConvertor
|
from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor,
|
||||||
from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor
|
PyTorchDataConvertor)
|
||||||
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
|
from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel
|
||||||
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
|
from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Tuple, List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -12,6 +12,7 @@ from torch.utils.data import DataLoader, TensorDataset
|
|||||||
from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
|
from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor
|
||||||
from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface
|
from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
class PyTorchTrainerInterface(ABC):
|
class PyTorchTrainerInterface(ABC):
|
||||||
|
|
||||||
@ -51,4 +50,4 @@ class PyTorchTrainerInterface(ABC):
|
|||||||
get_init_model method.
|
get_init_model method.
|
||||||
:checkpoint checkpoint: dict containing the model & optimizer state dicts,
|
:checkpoint checkpoint: dict containing the model & optimizer state dicts,
|
||||||
model_meta_data, etc..
|
model_meta_data, etc..
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user