mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encoders LoRA not trained
This commit is contained in:
@@ -214,6 +214,24 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
||||
action="store_true",
|
||||
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_l_dropout_rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_g_dropout_rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5_dropout_rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
|
||||
)
|
||||
|
||||
# copy from Diffusers
|
||||
parser.add_argument(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -48,13 +49,23 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
|
||||
|
||||
|
||||
class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None:
|
||||
def __init__(
    self,
    apply_lg_attn_mask: Optional[bool] = None,
    apply_t5_attn_mask: Optional[bool] = None,
    l_dropout_rate: float = 0.0,
    g_dropout_rate: float = 0.0,
    t5_dropout_rate: float = 0.0,
) -> None:
    """
    Remember the attention-mask defaults and the per-encoder dropout rates.

    Args:
        apply_lg_attn_mask: Default value for apply_lg_attn_mask
            (presumably for the CLIP-L / CLIP-G pair -- confirm against encode_tokens).
        apply_t5_attn_mask: Default value for apply_t5_attn_mask.
        l_dropout_rate: Chance (compared against random.random()) that the CLIP-L
            outputs are replaced by zeros. 0.0 disables dropout.
        g_dropout_rate: Same, for the CLIP-G outputs.
        t5_dropout_rate: Same, for the T5 outputs.
    """
    # Attention-mask defaults.
    self.apply_lg_attn_mask, self.apply_t5_attn_mask = apply_lg_attn_mask, apply_t5_attn_mask

    # Dropout rates, one per text encoder.
    self.l_dropout_rate, self.g_dropout_rate, self.t5_dropout_rate = (
        l_dropout_rate,
        g_dropout_rate,
        t5_dropout_rate,
    )
|
||||
|
||||
def encode_tokens(
|
||||
self,
|
||||
@@ -63,6 +74,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
tokens: List[torch.Tensor],
|
||||
apply_lg_attn_mask: Optional[bool] = False,
|
||||
apply_t5_attn_mask: Optional[bool] = False,
|
||||
enable_dropout: bool = True,
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
returned embeddings are not masked
|
||||
@@ -91,37 +103,92 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
g_attn_mask = None
|
||||
t5_attn_mask = None
|
||||
|
||||
# dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
|
||||
|
||||
if l_tokens is None or clip_l is None:
|
||||
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
|
||||
lg_out = None
|
||||
lg_pooled = None
|
||||
else:
|
||||
with torch.no_grad():
|
||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||
|
||||
drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
|
||||
if drop_l:
|
||||
l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype)
|
||||
l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype)
|
||||
if l_attn_mask is not None:
|
||||
l_attn_mask = torch.zeros_like(l_attn_mask)
|
||||
else:
|
||||
l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
|
||||
g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
|
||||
|
||||
prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
|
||||
l_pooled = prompt_embeds[0]
|
||||
l_out = prompt_embeds.hidden_states[-2]
|
||||
|
||||
drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
|
||||
if drop_g:
|
||||
g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
|
||||
g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
|
||||
if g_attn_mask is not None:
|
||||
g_attn_mask = torch.zeros_like(g_attn_mask)
|
||||
else:
|
||||
g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
|
||||
prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
|
||||
g_pooled = prompt_embeds[0]
|
||||
g_out = prompt_embeds.hidden_states[-2]
|
||||
|
||||
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
|
||||
if t5xxl is None or t5_tokens is None:
|
||||
t5_out = None
|
||||
else:
|
||||
t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
|
||||
with torch.no_grad():
|
||||
drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
|
||||
if drop_t5:
|
||||
t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype)
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask = torch.zeros_like(t5_attn_mask)
|
||||
else:
|
||||
t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
|
||||
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
|
||||
|
||||
# masks are used for attention masking in transformer
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
||||
|
||||
def drop_cached_text_encoder_outputs(
    self,
    lg_out: Optional[torch.Tensor],
    t5_out: Optional[torch.Tensor],
    lg_pooled: Optional[torch.Tensor],
    l_attn_mask: Optional[torch.Tensor],
    g_attn_mask: Optional[torch.Tensor],
    t5_attn_mask: Optional[torch.Tensor],
) -> Tuple[
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    """
    Apply per-sample dropout to already-computed text encoder outputs.

    Dropout here means zeroing out the embeddings (and the matching attention
    mask row) of one sample. Unlike encode_tokens there is no enable_dropout
    switch: dropout is always applied, governed only by the configured rates.
    The decision is drawn independently for every sample (index along dim 0)
    and for every encoder.

    The tensors are modified IN PLACE and the same objects are returned.

    Args:
        lg_out: CLIP-L / CLIP-G hidden states concatenated on the last dim, or None.
        t5_out: T5 hidden states, or None.
        lg_pooled: CLIP-L / CLIP-G pooled outputs concatenated on the last dim, or None.
        l_attn_mask: CLIP-L attention mask, or None.
        g_attn_mask: CLIP-G attention mask, or None.
        t5_attn_mask: T5 attention mask, or None.

    Returns:
        (lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask)
    """
    # Width of the CLIP-L slice; CLIP-G occupies the rest of the last dim
    # (encode_tokens concatenates CLIP-L first, then CLIP-G).
    clip_l_dim = 768

    if lg_out is not None:
        for i in range(lg_out.shape[0]):
            # The rate check short-circuits, so no random number is consumed
            # when the corresponding dropout is disabled.
            if self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate:
                lg_out[i, :, :clip_l_dim] = 0
                if lg_pooled is not None:
                    lg_pooled[i, :clip_l_dim] = 0
                if l_attn_mask is not None:
                    l_attn_mask[i] = 0

            if self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate:
                lg_out[i, :, clip_l_dim:] = 0
                if lg_pooled is not None:
                    lg_pooled[i, clip_l_dim:] = 0
                if g_attn_mask is not None:
                    g_attn_mask[i] = 0

    if t5_out is not None:
        for i in range(t5_out.shape[0]):
            if self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate:
                t5_out[i] = 0
                if t5_attn_mask is not None:
                    t5_attn_mask[i] = 0

    return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask
|
||||
|
||||
def concat_encodings(
|
||||
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -207,8 +274,14 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# always disable dropout during caching
|
||||
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
|
||||
tokenize_strategy,
|
||||
models,
|
||||
tokens_and_masks,
|
||||
apply_lg_attn_mask=self.apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=self.apply_t5_attn_mask,
|
||||
enable_dropout=False,
|
||||
)
|
||||
|
||||
if lg_out.dtype == torch.bfloat16:
|
||||
|
||||
Reference in New Issue
Block a user