Improve caching for caption dropout (caption_dropout_rate argument)

Duoong
2026-02-07 18:44:22 +07:00
parent 96a3ae2f87
commit d2d111e826
3 changed files with 102 additions and 32 deletions
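
When caption_dropout_rate is greater than zero, the changes below encode the empty caption "" once while the text encoder is still loaded and cache the result, so that dropped captions are later replaced by a real unconditional embedding instead of zeros. A minimal sketch of how the argument is presumably exposed, assuming ordinary argparse wiring (the parser itself is not part of this diff):

    # Minimal sketch, assuming ordinary argparse wiring; only the attribute name
    # caption_dropout_rate and its 0.0 default come from the getattr calls in the diff.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--caption_dropout_rate",
        type=float,
        default=0.0,
        help="Probability of replacing a caption with the cached unconditional (empty-caption) embedding.",
    )
    args = parser.parse_args(["--caption_dropout_rate", "0.1"])
    print(getattr(args, "caption_dropout_rate", 0.0))  # 0.1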

View File

@@ -265,6 +265,12 @@ def train(args):
enable_dropout=False,
)
# Pre-cache unconditional embeddings for caption dropout before the text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0.0:
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
accelerator.wait_for_everyone()
# free text encoder memory
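
The ordering in the hunk above is the point: cache_uncond_embeddings runs while qwen3_text_encoder is still resident, all ranks synchronize, and only afterwards is the encoder freed, which works because the cached result lives on the CPU inside the strategy. A toy illustration of that idea with a stand-in encoder:

    # Toy, self-contained illustration: encode the empty caption "" once, zero its
    # padding positions, and park the result on the CPU so the encoder can be freed.
    # ToyTextEncoder and the token ids are stand-ins, not the real Qwen3 encoder or tokenizer.
    import torch
    import torch.nn as nn

    class ToyTextEncoder(nn.Module):
        def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, hidden_size)

        def forward(self, input_ids, attention_mask):
            hidden_states = self.emb(input_ids)
            hidden_states[~attention_mask.bool()] = 0  # zero out padding positions, as in the diff
            return hidden_states

    encoder = ToyTextEncoder()
    # "Tokenizing" "" still yields a couple of special tokens, so the mask is not all zeros.
    empty_ids = torch.tensor([[1, 2] + [0] * 14])
    empty_mask = torch.tensor([[1, 1] + [0] * 14])
    with torch.no_grad():
        uncond_embeds = encoder(empty_ids, empty_mask)  # (1, 16, 8)
    uncond_embeds_cpu = uncond_embeds.cpu()  # cached once; the encoder can now be deleted
    del encoder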

View File

@@ -236,6 +236,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
# Pre-cache unconditional embeddings for caption dropout before the text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
if caption_dropout_rate > 0.0:
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
with accelerator.autocast():
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
accelerator.wait_for_everyone()
# move text encoder back to cpu
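
The hunk above fetches the already-registered strategies through strategy_base.TextEncodingStrategy.get_strategy() and strategy_base.TokenizeStrategy.get_strategy() rather than receiving them as arguments, which relies on the strategies being class-level singletons. A simplified sketch of that registry pattern:

    # Simplified sketch of the get_strategy()/set_strategy() singleton pattern assumed here;
    # set_strategy() and the class layout are assumptions -- only get_strategy() appears in the diff.
    from typing import Optional

    class TextEncodingStrategy:
        _strategy: Optional["TextEncodingStrategy"] = None

        @classmethod
        def set_strategy(cls, strategy: "TextEncodingStrategy") -> None:
            cls._strategy = strategy  # registered once during trainer setup

        @classmethod
        def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
            return cls._strategy  # retrieved later without threading it through arguments

    class AnimaTextEncodingStrategy(TextEncodingStrategy):
        def __init__(self, dropout_rate: float = 0.0) -> None:
            self.dropout_rate = dropout_rate

    TextEncodingStrategy.set_strategy(AnimaTextEncodingStrategy(dropout_rate=0.1))
    assert isinstance(TextEncodingStrategy.get_strategy(), AnimaTextEncodingStrategy)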

View File

@@ -89,6 +89,34 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
dropout_rate: float = 0.0,
) -> None:
self.dropout_rate = dropout_rate
# Cached unconditional embeddings (from encoding the empty caption "")
# Must be initialized via cache_uncond_embeddings() before the text encoder is deleted
self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
def cache_uncond_embeddings(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
) -> None:
"""Pre-encode empty caption "" and cache the unconditional embeddings.
Must be called before the text encoder is deleted from GPU.
This matches diffusion-pipe-main behavior where empty caption embeddings
are pre-cached and swapped in during caption dropout.
"""
logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
tokens = tokenize_strategy.tokenize("")
with torch.no_grad():
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
# Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
self._uncond_prompt_embeds = uncond_outputs[0].cpu()
self._uncond_attn_mask = uncond_outputs[1].cpu()
self._uncond_t5_input_ids = uncond_outputs[2].cpu()
self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
logger.info(" Unconditional embeddings cached successfully")
def encode_tokens(
self,
@@ -110,7 +138,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
# Handle dropout (zero out entire batch items)
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
batch_size = qwen3_input_ids.shape[0]
non_drop_indices = []
for i in range(batch_size):
@@ -118,53 +146,68 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
if not drop:
non_drop_indices.append(i)
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
# Only encode non-dropped items
nd_input_ids = qwen3_input_ids[non_drop_indices]
nd_attn_mask = qwen3_attn_mask[non_drop_indices]
# Only encode non-dropped items to save compute
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
elif len(non_drop_indices) == batch_size:
nd_input_ids = qwen3_input_ids
nd_attn_mask = qwen3_attn_mask
nd_input_ids = qwen3_input_ids.to(encoder_device)
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
else:
nd_input_ids = None
nd_attn_mask = None
if nd_input_ids is not None:
nd_input_ids = nd_input_ids.to(qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device)
nd_attn_mask = nd_attn_mask.to(nd_input_ids.device)
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
# Fill back dropped items
# Build full batch: fill non-dropped with encoded, dropped with unconditional
if len(non_drop_indices) == batch_size:
prompt_embeds = nd_encoded_text
attn_mask = qwen3_attn_mask
attn_mask = qwen3_attn_mask.to(encoder_device)
else:
# Get unconditional embeddings
if self._uncond_prompt_embeds is not None:
uncond_pe = self._uncond_prompt_embeds[0]
uncond_am = self._uncond_attn_mask[0]
uncond_t5_ids = self._uncond_t5_input_ids[0]
uncond_t5_am = self._uncond_t5_attn_mask[0]
else:
# Encode empty caption on-the-fly (text encoder still available)
uncond_tokens = tokenize_strategy.tokenize("")
uncond_ids = uncond_tokens[0].to(encoder_device)
uncond_mask = uncond_tokens[1].to(encoder_device)
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
uncond_pe = uncond_out.last_hidden_state[0]
uncond_pe[~uncond_mask[0].bool()] = 0
uncond_am = uncond_mask[0]
uncond_t5_ids = uncond_tokens[2][0]
uncond_t5_am = uncond_tokens[3][0]
seq_len = qwen3_input_ids.shape[1]
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else 1024
placeholder_dtype = nd_encoded_text.dtype if nd_encoded_text is not None else torch.float32
device = nd_input_ids.device if nd_input_ids is not None else qwen3_input_ids.device
prompt_embeds = torch.zeros(
(batch_size, seq_len, hidden_size), device=device, dtype=placeholder_dtype
)
attn_mask = torch.zeros(
(batch_size, seq_len), device=device, dtype=qwen3_attn_mask.dtype
)
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
if len(non_drop_indices) > 0:
prompt_embeds[non_drop_indices] = nd_encoded_text
attn_mask[non_drop_indices] = nd_attn_mask
# Zero out t5_input_ids and t5_attn_mask for dropped items
# so the LLM adapter sees a consistent unconditional signal
# Fill dropped items with unconditional embeddings
t5_input_ids = t5_input_ids.clone()
t5_attn_mask = t5_attn_mask.clone()
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
for i in drop_indices:
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
@@ -178,7 +221,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
"""Apply dropout to cached text encoder outputs.
Called during training when using cached outputs.
Clones tensors to avoid corrupting cached data.
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior.
"""
if prompt_embeds is not None and self.dropout_rate > 0.0:
# Clone to avoid in-place modification of cached tensors
@@ -192,13 +236,25 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
for i in range(prompt_embeds.shape[0]):
if random.random() < self.dropout_rate:
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
if attn_mask is not None:
attn_mask[i] = torch.zeros_like(attn_mask[i])
if t5_input_ids is not None:
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
if self._uncond_prompt_embeds is not None:
# Use pre-cached unconditional embeddings
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if attn_mask is not None:
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
if t5_input_ids is not None:
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
if t5_attn_mask is not None:
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
else:
# Fallback: zero out (should not happen if cache_uncond_embeddings was called)
logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
if attn_mask is not None:
attn_mask[i] = torch.zeros_like(attn_mask[i])
if t5_input_ids is not None:
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
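
On the cached-outputs path, the logic above amounts to: clone the cached batch, then overwrite every row that loses the dropout coin flip with the single pre-cached unconditional row (falling back to zeros only if the cache was never filled). A self-contained sketch of that swap on dummy tensors:

    # Self-contained sketch of the swap applied to cached outputs; the function name,
    # tensor shapes, and dummy values are illustrative and not taken from the codebase.
    import random
    import torch

    def swap_in_uncond(prompt_embeds, attn_mask, uncond_embeds, uncond_mask, dropout_rate):
        # Clone so the cached tensors themselves are never modified in place.
        prompt_embeds = prompt_embeds.clone()
        attn_mask = attn_mask.clone()
        for i in range(prompt_embeds.shape[0]):
            if random.random() < dropout_rate:
                prompt_embeds[i] = uncond_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
                attn_mask[i] = uncond_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
        return prompt_embeds, attn_mask

    batch, seq_len, hidden = 4, 16, 64
    cached_embeds = torch.randn(batch, seq_len, hidden)
    cached_mask = torch.ones(batch, seq_len, dtype=torch.long)
    uncond_embeds = torch.randn(1, seq_len, hidden)  # stand-in for the cached "" embedding
    uncond_mask = torch.ones(1, seq_len, dtype=torch.long)
    new_embeds, new_mask = swap_in_uncond(cached_embeds, cached_mask, uncond_embeds, uncond_mask, 0.25)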