fix imports
This commit is contained in:
		| @@ -8,7 +8,7 @@ from pandas import DataFrame | ||||
|  | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.freqai.freqai_interface import IFreqaiModel | ||||
| from freqtrade.freqai.torch import PyTorchDataConvertor | ||||
| from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|   | ||||
| @@ -4,8 +4,8 @@ import torch | ||||
|  | ||||
| from freqtrade.freqai.base_models.BasePyTorchClassifier import BasePyTorchClassifier | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.freqai.torch import PyTorchDataConvertor | ||||
| from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor | ||||
| from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor, | ||||
|                                                          PyTorchDataConvertor) | ||||
| from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel | ||||
| from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer | ||||
|  | ||||
| @@ -42,7 +42,10 @@ class PyTorchMLPClassifier(BasePyTorchClassifier): | ||||
|  | ||||
|     @property | ||||
|     def data_convertor(self) -> PyTorchDataConvertor: | ||||
|         return DefaultPyTorchDataConvertor(target_tensor_type=torch.long, squeeze_target_tensor=True) | ||||
|         return DefaultPyTorchDataConvertor( | ||||
|             target_tensor_type=torch.long, | ||||
|             squeeze_target_tensor=True | ||||
|         ) | ||||
|  | ||||
|     def __init__(self, **kwargs) -> None: | ||||
|         super().__init__(**kwargs) | ||||
|   | ||||
| @@ -4,8 +4,8 @@ import torch | ||||
|  | ||||
| from freqtrade.freqai.base_models.BasePyTorchRegressor import BasePyTorchRegressor | ||||
| from freqtrade.freqai.data_kitchen import FreqaiDataKitchen | ||||
| from freqtrade.freqai.torch import PyTorchDataConvertor | ||||
| from freqtrade.freqai.torch.PyTorchDataConvertor import DefaultPyTorchDataConvertor | ||||
| from freqtrade.freqai.torch.PyTorchDataConvertor import (DefaultPyTorchDataConvertor, | ||||
|                                                          PyTorchDataConvertor) | ||||
| from freqtrade.freqai.torch.PyTorchMLPModel import PyTorchMLPModel | ||||
| from freqtrade.freqai.torch.PyTorchModelTrainer import PyTorchModelTrainer | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| import logging | ||||
| from typing import Tuple, List | ||||
| from typing import List | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|   | ||||
| @@ -12,6 +12,7 @@ from torch.utils.data import DataLoader, TensorDataset | ||||
| from freqtrade.freqai.torch.PyTorchDataConvertor import PyTorchDataConvertor | ||||
| from freqtrade.freqai.torch.PyTorchTrainerInterface import PyTorchTrainerInterface | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,12 +1,11 @@ | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import Any, Dict, List, Optional, Tuple | ||||
| from pathlib import Path | ||||
| from typing import Dict, List | ||||
|  | ||||
| import pandas as pd | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from pathlib import Path | ||||
|  | ||||
|  | ||||
| class PyTorchTrainerInterface(ABC): | ||||
|  | ||||
| @@ -51,4 +50,4 @@ class PyTorchTrainerInterface(ABC): | ||||
|         get_init_model method. | ||||
|         :checkpoint checkpoint: dict containing the model & optimizer state dicts, | ||||
|         model_meta_data, etc.. | ||||
|         """ | ||||
|         """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user