From fb312acb7f07605e46662099f62181de197fb490 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Wed, 8 Feb 2023 18:54:55 +0800
Subject: [PATCH] support DistributedDataParallel

---
 train_network.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/train_network.py b/train_network.py
index 710055e0..c2f9cbf6 100644
--- a/train_network.py
+++ b/train_network.py
@@ -267,6 +267,14 @@ def train(args):
     unet.eval()
     text_encoder.eval()
 
+  # support DistributedDataParallel
+  try:
+    text_encoder = text_encoder.module
+    unet = unet.module
+    network = network.module
+  except:
+    pass
+
   network.prepare_grad_etc(text_encoder, unet)
 
   if not cache_latents: