diff --git a/train_network.py b/train_network.py index 710055e0..c2f9cbf6 100644 --- a/train_network.py +++ b/train_network.py @@ -267,6 +267,14 @@ def train(args): unet.eval() text_encoder.eval() + # support DistributedDataParallel + try: + text_encoder = text_encoder.module + unet = unet.module + network = network.module + except: + pass + network.prepare_grad_etc(text_encoder, unet) if not cache_latents: