From 56bf7611644402996072bd8f909cf828ec7b27cc Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 26 Oct 2024 17:29:24 +0900 Subject: [PATCH] fix errors in SD3 LoRA training with Text Encoders close #1724 --- library/strategy_sd3.py | 26 +++++++++++++------------- sd3_train_network.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index dd08cf00..a27e99e6 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -68,9 +68,9 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): returned embeddings are not masked """ clip_l, clip_g, t5xxl = models - clip_l: CLIPTextModel - clip_g: CLIPTextModelWithProjection - t5xxl: T5EncoderModel + clip_l: Optional[CLIPTextModel] + clip_g: Optional[CLIPTextModelWithProjection] + t5xxl: Optional[T5EncoderModel] if apply_lg_attn_mask is None: apply_lg_attn_mask = self.apply_lg_attn_mask @@ -84,25 +84,23 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): if not apply_lg_attn_mask: l_attn_mask = None g_attn_mask = None - else: - l_attn_mask = l_attn_mask.to(clip_l.device) - g_attn_mask = g_attn_mask.to(clip_g.device) if not apply_t5_attn_mask: t5_attn_mask = None - else: - t5_attn_mask = t5_attn_mask.to(t5xxl.device) else: l_attn_mask = None g_attn_mask = None t5_attn_mask = None - if l_tokens is None: + if l_tokens is None or clip_l is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None lg_pooled = None else: with torch.no_grad(): assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None + g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None + prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) l_pooled = prompt_embeds[0] l_out = prompt_embeds.hidden_states[-2] @@ -114,13 +112,15 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None lg_out = torch.cat([l_out, g_out], dim=-1) - if t5xxl is not None and t5_tokens is not None: + if t5xxl is None or t5_tokens is None: + t5_out = None + else: + t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None with torch.no_grad(): t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) - else: - t5_out = None - return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer + # masks are used for attention masking in transformer + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor diff --git a/sd3_train_network.py b/sd3_train_network.py index ecacf16c..129afed5 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -134,7 +134,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): def get_models_for_text_encoding(self, args, accelerator, text_encoders): if args.cache_text_encoder_outputs: if self.train_clip and not self.train_t5xxl: - return text_encoders[0:2] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached else: return None # no text encoders are needed for encoding because both are cached else: