diff --git a/train_network.py b/train_network.py index fdc466ec..b783379b 100644 --- a/train_network.py +++ b/train_network.py @@ -1,6 +1,7 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from torch.optim import Optimizer from torch.cuda.amp import autocast +from torch.nn.parallel import DistributedDataParallel as DDP from typing import Optional, Union import importlib import argparse