diff --git a/train_network.py b/train_network.py index c2f9cbf6..fc387bc3 100644 --- a/train_network.py +++ b/train_network.py @@ -1,5 +1,6 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from torch.optim import Optimizer +from torch.cuda.amp import autocast from typing import Optional, Union import importlib import argparse