From 2b1a3080e7ddc329e3a3bf59126d8ccac80d0dae Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 12 Feb 2023 15:32:38 +0800 Subject: [PATCH] Add type checking --- train_network.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index 852aea8e..90771b31 100644 --- a/train_network.py +++ b/train_network.py @@ -267,18 +267,19 @@ def train(args): text_encoder.train() # set top parameter requires_grad = True for gradient checkpointing works - text_encoder.text_model.embeddings.requires_grad_(True) + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) + else: + text_encoder.text_model.embeddings.requires_grad_(True) else: unet.eval() text_encoder.eval() # support DistributedDataParallel - try: - text_encoder = text_encoder.module - unet = unet.module - network = network.module - except: - pass + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module network.prepare_grad_etc(text_encoder, unet)