mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
Improving caching with argument caption_dropout_rate
This commit is contained in:
@@ -265,6 +265,12 @@ def train(args):
|
||||
enable_dropout=False,
|
||||
)
|
||||
|
||||
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
|
||||
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||
if caption_dropout_rate > 0.0:
|
||||
with accelerator.autocast():
|
||||
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# free text encoder memory
|
||||
|
||||
@@ -236,6 +236,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||
|
||||
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
|
||||
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
if caption_dropout_rate > 0.0:
|
||||
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
|
||||
with accelerator.autocast():
|
||||
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# move text encoder back to cpu
|
||||
|
||||
@@ -89,6 +89,34 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
dropout_rate: float = 0.0,
|
||||
) -> None:
|
||||
self.dropout_rate = dropout_rate
|
||||
# Cached unconditional embeddings (from encoding empty caption "")
|
||||
# Must be initialized via cache_uncond_embeddings() before text encoder is deleted
|
||||
self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
|
||||
self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
|
||||
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
|
||||
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
|
||||
|
||||
def cache_uncond_embeddings(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
) -> None:
|
||||
"""Pre-encode empty caption "" and cache the unconditional embeddings.
|
||||
|
||||
Must be called before the text encoder is deleted from GPU.
|
||||
This matches diffusion-pipe-main behavior where empty caption embeddings
|
||||
are pre-cached and swapped in during caption dropout.
|
||||
"""
|
||||
logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
|
||||
tokens = tokenize_strategy.tokenize("")
|
||||
with torch.no_grad():
|
||||
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
|
||||
# Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
|
||||
self._uncond_prompt_embeds = uncond_outputs[0].cpu()
|
||||
self._uncond_attn_mask = uncond_outputs[1].cpu()
|
||||
self._uncond_t5_input_ids = uncond_outputs[2].cpu()
|
||||
self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
|
||||
logger.info(" Unconditional embeddings cached successfully")
|
||||
|
||||
def encode_tokens(
|
||||
self,
|
||||
@@ -110,7 +138,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
qwen3_text_encoder = models[0]
|
||||
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
|
||||
|
||||
# Handle dropout (zero out entire batch items)
|
||||
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
|
||||
batch_size = qwen3_input_ids.shape[0]
|
||||
non_drop_indices = []
|
||||
for i in range(batch_size):
|
||||
@@ -118,53 +146,68 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
if not drop:
|
||||
non_drop_indices.append(i)
|
||||
|
||||
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
|
||||
|
||||
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
|
||||
# Only encode non-dropped items
|
||||
nd_input_ids = qwen3_input_ids[non_drop_indices]
|
||||
nd_attn_mask = qwen3_attn_mask[non_drop_indices]
|
||||
# Only encode non-dropped items to save compute
|
||||
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
|
||||
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
|
||||
elif len(non_drop_indices) == batch_size:
|
||||
nd_input_ids = qwen3_input_ids
|
||||
nd_attn_mask = qwen3_attn_mask
|
||||
nd_input_ids = qwen3_input_ids.to(encoder_device)
|
||||
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
|
||||
else:
|
||||
nd_input_ids = None
|
||||
nd_attn_mask = None
|
||||
|
||||
if nd_input_ids is not None:
|
||||
nd_input_ids = nd_input_ids.to(qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device)
|
||||
nd_attn_mask = nd_attn_mask.to(nd_input_ids.device)
|
||||
|
||||
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
|
||||
nd_encoded_text = outputs.last_hidden_state
|
||||
# Zero out padding positions
|
||||
nd_encoded_text[~nd_attn_mask.bool()] = 0
|
||||
|
||||
# Fill back dropped items
|
||||
# Build full batch: fill non-dropped with encoded, dropped with unconditional
|
||||
if len(non_drop_indices) == batch_size:
|
||||
prompt_embeds = nd_encoded_text
|
||||
attn_mask = qwen3_attn_mask
|
||||
attn_mask = qwen3_attn_mask.to(encoder_device)
|
||||
else:
|
||||
# Get unconditional embeddings
|
||||
if self._uncond_prompt_embeds is not None:
|
||||
uncond_pe = self._uncond_prompt_embeds[0]
|
||||
uncond_am = self._uncond_attn_mask[0]
|
||||
uncond_t5_ids = self._uncond_t5_input_ids[0]
|
||||
uncond_t5_am = self._uncond_t5_attn_mask[0]
|
||||
else:
|
||||
# Encode empty caption on-the-fly (text encoder still available)
|
||||
uncond_tokens = tokenize_strategy.tokenize("")
|
||||
uncond_ids = uncond_tokens[0].to(encoder_device)
|
||||
uncond_mask = uncond_tokens[1].to(encoder_device)
|
||||
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
|
||||
uncond_pe = uncond_out.last_hidden_state[0]
|
||||
uncond_pe[~uncond_mask[0].bool()] = 0
|
||||
uncond_am = uncond_mask[0]
|
||||
uncond_t5_ids = uncond_tokens[2][0]
|
||||
uncond_t5_am = uncond_tokens[3][0]
|
||||
|
||||
seq_len = qwen3_input_ids.shape[1]
|
||||
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else 1024
|
||||
placeholder_dtype = nd_encoded_text.dtype if nd_encoded_text is not None else torch.float32
|
||||
device = nd_input_ids.device if nd_input_ids is not None else qwen3_input_ids.device
|
||||
prompt_embeds = torch.zeros(
|
||||
(batch_size, seq_len, hidden_size), device=device, dtype=placeholder_dtype
|
||||
)
|
||||
attn_mask = torch.zeros(
|
||||
(batch_size, seq_len), device=device, dtype=qwen3_attn_mask.dtype
|
||||
)
|
||||
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
|
||||
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
|
||||
|
||||
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
|
||||
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
|
||||
|
||||
if len(non_drop_indices) > 0:
|
||||
prompt_embeds[non_drop_indices] = nd_encoded_text
|
||||
attn_mask[non_drop_indices] = nd_attn_mask
|
||||
|
||||
# Zero out t5_input_ids and t5_attn_mask for dropped items
|
||||
# so the LLM adapter sees a consistent unconditional signal
|
||||
# Fill dropped items with unconditional embeddings
|
||||
t5_input_ids = t5_input_ids.clone()
|
||||
t5_attn_mask = t5_attn_mask.clone()
|
||||
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
|
||||
for i in drop_indices:
|
||||
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
|
||||
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
|
||||
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
|
||||
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
|
||||
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
|
||||
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
|
||||
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
@@ -178,7 +221,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
"""Apply dropout to cached text encoder outputs.
|
||||
|
||||
Called during training when using cached outputs.
|
||||
Clones tensors to avoid corrupting cached data.
|
||||
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
|
||||
to match diffusion-pipe-main behavior.
|
||||
"""
|
||||
if prompt_embeds is not None and self.dropout_rate > 0.0:
|
||||
# Clone to avoid in-place modification of cached tensors
|
||||
@@ -192,13 +236,25 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
for i in range(prompt_embeds.shape[0]):
|
||||
if random.random() < self.dropout_rate:
|
||||
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
|
||||
if attn_mask is not None:
|
||||
attn_mask[i] = torch.zeros_like(attn_mask[i])
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
|
||||
if self._uncond_prompt_embeds is not None:
|
||||
# Use pre-cached unconditional embeddings
|
||||
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
if attn_mask is not None:
|
||||
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
|
||||
else:
|
||||
# Fallback: zero out (should not happen if cache_uncond_embeddings was called)
|
||||
logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
|
||||
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
|
||||
if attn_mask is not None:
|
||||
attn_mask[i] = torch.zeros_like(attn_mask[i])
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
|
||||
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user