support DistributedDataParallel

This commit is contained in:
Isotr0py
2023-02-08 18:54:55 +08:00
parent 938bd71844
commit fb312acb7f

View File

@@ -267,6 +267,14 @@ def train(args):
unet.eval()
text_encoder.eval()
# Unwrap torch.nn.parallel.DistributedDataParallel wrappers: DDP exposes
# the underlying model as `.module`. Using getattr with the object itself
# as the default leaves non-wrapped models untouched, and — unlike the
# previous bare `try/except: pass` — it (a) cannot swallow unrelated
# exceptions (including KeyboardInterrupt/SystemExit) and (b) cannot leave
# the three references inconsistently unwrapped when only some of them are
# DDP-wrapped (the old code aborted at the first AttributeError, after
# having already rebound earlier names).
text_encoder = getattr(text_encoder, "module", text_encoder)
unet = getattr(unet, "module", unet)
network = getattr(network, "module", network)
network.prepare_grad_etc(text_encoder, unet)
if not cache_latents: