mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix SD3 LoRA training to work (WIP)
This commit is contained in:
@@ -111,13 +111,13 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
lg_pooled = None
|
||||
else:
|
||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||
|
||||
|
||||
drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
|
||||
if drop_l:
|
||||
l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype)
|
||||
l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype)
|
||||
l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype)
|
||||
l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype)
|
||||
if l_attn_mask is not None:
|
||||
l_attn_mask = torch.zeros_like(l_attn_mask)
|
||||
l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device)
|
||||
else:
|
||||
l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
|
||||
prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
|
||||
@@ -126,10 +126,10 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
|
||||
if drop_g:
|
||||
g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
|
||||
g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
|
||||
g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype)
|
||||
g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype)
|
||||
if g_attn_mask is not None:
|
||||
g_attn_mask = torch.zeros_like(g_attn_mask)
|
||||
g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device)
|
||||
else:
|
||||
g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
|
||||
prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
|
||||
@@ -144,9 +144,9 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
else:
|
||||
drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
|
||||
if drop_t5:
|
||||
t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype)
|
||||
t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype)
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask = torch.zeros_like(t5_attn_mask)
|
||||
t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device)
|
||||
else:
|
||||
t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
|
||||
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
|
||||
@@ -187,7 +187,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
|
||||
|
||||
return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
||||
|
||||
def concat_encodings(
|
||||
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
|
||||
|
||||
Reference in New Issue
Block a user