diff --git a/freqtrade/freqai/base_models/PyTorchModelTrainer.py b/freqtrade/freqai/base_models/PyTorchModelTrainer.py index 02ff35085..fc0a7600e 100644 --- a/freqtrade/freqai/base_models/PyTorchModelTrainer.py +++ b/freqtrade/freqai/base_models/PyTorchModelTrainer.py @@ -1,6 +1,7 @@ import logging from pathlib import Path from typing import Dict +from torch.optim import Optimizer import torch import torch.nn as nn @@ -15,7 +16,7 @@ class PyTorchModelTrainer: def __init__( self, model: nn.Module, - optimizer: nn.Module, + optimizer: Optimizer, criterion: nn.Module, device: str, batch_size: int,