type hints fixes

This commit is contained in:
Yinon Polak 2023-03-06 20:15:36 +02:00
parent 5dd60eda36
commit 4241bff32a

View File

@ -1,6 +1,7 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
from torch.optim import Optimizer
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -15,7 +16,7 @@ class PyTorchModelTrainer:
def __init__( def __init__(
self, self,
model: nn.Module, model: nn.Module,
optimizer: nn.Module, optimizer: Optimizer,
criterion: nn.Module, criterion: nn.Module,
device: str, device: str,
batch_size: int, batch_size: int,