mirror of https://github.com/kohya-ss/sd-scripts.git
fix errors in SD3 LoRA training with Text Encoders close #1724
@@ -68,9 +68,9 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
         returned embeddings are not masked
         """
         clip_l, clip_g, t5xxl = models
-        clip_l: CLIPTextModel
-        clip_g: CLIPTextModelWithProjection
-        t5xxl: T5EncoderModel
+        clip_l: Optional[CLIPTextModel]
+        clip_g: Optional[CLIPTextModelWithProjection]
+        t5xxl: Optional[T5EncoderModel]
 
         if apply_lg_attn_mask is None:
             apply_lg_attn_mask = self.apply_lg_attn_mask
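The annotations become Optional because any of the three encoders may now be passed as None when its outputs are cached. The surrounding context also shows how apply_lg_attn_mask falls back to the strategy's stored default when the caller passes None. A minimal standalone sketch of that override-with-default pattern, using simplified names rather than the repo's exact API:

from typing import Optional

class MaskDefaults:
    def __init__(self, apply_lg_attn_mask: bool = False) -> None:
        # default chosen once when the strategy is constructed
        self.apply_lg_attn_mask = apply_lg_attn_mask

    def resolve(self, apply_lg_attn_mask: Optional[bool] = None) -> bool:
        # None means "use the stored default"; an explicit bool overrides per call
        if apply_lg_attn_mask is None:
            apply_lg_attn_mask = self.apply_lg_attn_mask
        return apply_lg_attn_mask

defaults = MaskDefaults(apply_lg_attn_mask=True)
assert defaults.resolve() is True        # falls back to the constructor default
assert defaults.resolve(False) is False  # per-call override wins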
@@ -84,25 +84,23 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
             if not apply_lg_attn_mask:
                 l_attn_mask = None
                 g_attn_mask = None
-            else:
-                l_attn_mask = l_attn_mask.to(clip_l.device)
-                g_attn_mask = g_attn_mask.to(clip_g.device)
             if not apply_t5_attn_mask:
                 t5_attn_mask = None
-            else:
-                t5_attn_mask = t5_attn_mask.to(t5xxl.device)
         else:
             l_attn_mask = None
             g_attn_mask = None
             t5_attn_mask = None
 
-        if l_tokens is None:
+        if l_tokens is None or clip_l is None:
             assert g_tokens is None, "g_tokens must be None if l_tokens is None"
             lg_out = None
             lg_pooled = None
         else:
             with torch.no_grad():
                 assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
+                l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
+                g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
+
                 prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
                 l_pooled = prompt_embeds[0]
                 l_out = prompt_embeds.hidden_states[-2]
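The key change in this hunk: the CLIP attention masks are no longer moved to a device up front, where clip_l or clip_g may be None and the mask itself may be None, but only inside the branch where the CLIP encoders are known to exist, and only when the mask is present. A minimal sketch of the guarded move, with a hypothetical mask and device standing in for l_attn_mask and clip_l.device:

from typing import Optional
import torch

l_attn_mask: Optional[torch.Tensor] = None   # e.g. masking disabled, so no mask was built
device = torch.device("cpu")                 # stand-in for clip_l.device

# An unconditional l_attn_mask.to(device) raises AttributeError when the mask is None;
# the guarded form simply propagates None so downstream code skips masking.
l_attn_mask = l_attn_mask.to(device) if l_attn_mask is not None else None
assert l_attn_mask is None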
@@ -114,13 +112,15 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
         lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
         lg_out = torch.cat([l_out, g_out], dim=-1)
 
-        if t5xxl is not None and t5_tokens is not None:
+        if t5xxl is None or t5_tokens is None:
+            t5_out = None
+        else:
+            t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
             with torch.no_grad():
                 t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
-        else:
-            t5_out = None
 
-        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]  # masks are used for attention masking in transformer
+        # masks are used for attention masking in transformer
+        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
 
     def concat_encodings(
         self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
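The T5 branch is reordered so the "nothing to encode" case comes first: if either the T5XXL model or its tokens are missing (for example because the T5 outputs are cached), t5_out is simply None, and the device move for the mask happens only in the branch where t5xxl exists. A small sketch of that shape, with a generic callable standing in for the encoder:

from typing import Callable, Optional
import torch

def maybe_encode(encoder: Optional[Callable[[torch.Tensor], torch.Tensor]],
                 tokens: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Bail out first when the encoder or its tokens are absent (outputs cached),
    # so encoder-specific work (device moves, forward pass) only runs in the else path.
    if encoder is None or tokens is None:
        return None
    with torch.no_grad():
        return encoder(tokens)

# With T5XXL outputs cached, the strategy receives None and propagates None for t5_out.
assert maybe_encode(None, torch.zeros(1, 4, dtype=torch.long)) is None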
@@ -134,7 +134,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
     def get_models_for_text_encoding(self, args, accelerator, text_encoders):
         if args.cache_text_encoder_outputs:
             if self.train_clip and not self.train_t5xxl:
-                return text_encoders[0:2]  # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached
+                return text_encoders[0:2] + [None]  # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached
             else:
                 return None  # no text encoders are needed for encoding because both are cached
         else:
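The padding matters because Sd3TextEncodingStrategy.encode_tokens unpacks exactly three entries (clip_l, clip_g, t5xxl = models). Returning only the two CLIP encoders cannot satisfy that unpack, while padding the cached T5XXL slot with None keeps the unpack valid and routes encoding into the new t5xxl-is-None branch. A minimal sketch with placeholder objects standing in for the encoders:

clip_l, clip_g, t5xxl = object(), object(), object()   # placeholders for the real encoders
text_encoders = [clip_l, clip_g, t5xxl]

# Old behaviour: two entries returned when T5XXL outputs are cached.
try:
    l, g, t5 = text_encoders[0:2]
except ValueError as e:
    print("old:", e)  # not enough values to unpack

# Fixed behaviour: pad the cached slot with None so the three-way unpack succeeds
# and the strategy takes its t5xxl-is-None path.
l, g, t5 = text_encoders[0:2] + [None]
assert t5 is None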