type hints fixes
This commit is contained in:
parent
5dd60eda36
commit
4241bff32a
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -15,7 +16,7 @@ class PyTorchModelTrainer:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
optimizer: nn.Module,
|
optimizer: Optimizer,
|
||||||
criterion: nn.Module,
|
criterion: nn.Module,
|
||||||
device: str,
|
device: str,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
Loading…
Reference in New Issue
Block a user