diff --git a/anima_train.py b/anima_train.py
index d703d02e..9fc1ec80 100644
--- a/anima_train.py
+++ b/anima_train.py
@@ -265,6 +265,12 @@ def train(args):
             enable_dropout=False,
         )
 
+    # Pre-cache unconditional embeddings for caption dropout before the text encoder is deleted
+    caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
+    if caption_dropout_rate > 0.0:
+        with accelerator.autocast():
+            text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
+
     accelerator.wait_for_everyone()
 
     # free text encoder memory
diff --git a/anima_train_network.py b/anima_train_network.py
index 2416c6d2..57ad1681 100644
--- a/anima_train_network.py
+++ b/anima_train_network.py
@@ -236,6 +236,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
             )
             self.sample_prompts_te_outputs = sample_prompts_te_outputs
 
+        # Pre-cache unconditional embeddings for caption dropout before the text encoder is deleted
+        caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
+        if caption_dropout_rate > 0.0:
+            text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
+            tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
+            with accelerator.autocast():
+                text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
+
         accelerator.wait_for_everyone()
 
         # move text encoder back to cpu
diff --git a/library/strategy_anima.py b/library/strategy_anima.py
index a2375e88..9c9b0126 100644
--- a/library/strategy_anima.py
+++ b/library/strategy_anima.py
@@ -89,6 +89,34 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         dropout_rate: float = 0.0,
     ) -> None:
         self.dropout_rate = dropout_rate
+        # Cached unconditional embeddings (from encoding the empty caption "").
+        # Must be initialized via cache_uncond_embeddings() before the text encoder is deleted.
+        self._uncond_prompt_embeds: Optional[torch.Tensor] = None  # (1, seq_len, hidden)
+        self._uncond_attn_mask: Optional[torch.Tensor] = None  # (1, seq_len)
+        self._uncond_t5_input_ids: Optional[torch.Tensor] = None  # (1, t5_seq_len)
+        self._uncond_t5_attn_mask: Optional[torch.Tensor] = None  # (1, t5_seq_len)
+
+    def cache_uncond_embeddings(
+        self,
+        tokenize_strategy: TokenizeStrategy,
+        models: List[Any],
+    ) -> None:
+        """Pre-encode the empty caption "" and cache the unconditional embeddings.
+
+        Must be called before the text encoder is deleted from GPU.
+        This matches diffusion-pipe-main behavior, where empty-caption embeddings
+        are pre-cached and swapped in during caption dropout.
+        """
+        logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
+        tokens = tokenize_strategy.tokenize("")
+        with torch.no_grad():
+            uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
+        # Store as CPU tensors (1, seq_len, ...) to avoid wasting GPU memory
+        self._uncond_prompt_embeds = uncond_outputs[0].cpu()
+        self._uncond_attn_mask = uncond_outputs[1].cpu()
+        self._uncond_t5_input_ids = uncond_outputs[2].cpu()
+        self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
+        logger.info("  Unconditional embeddings cached successfully")
 
     def encode_tokens(
         self,
@@ -110,7 +138,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         qwen3_text_encoder = models[0]
         qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
 
-        # Handle dropout (zero out entire batch items)
+        # Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
         batch_size = qwen3_input_ids.shape[0]
         non_drop_indices = []
         for i in range(batch_size):
@@ -118,53 +146,68 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
             if not drop:
                 non_drop_indices.append(i)
 
+        encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
+
         if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
-            # Only encode non-dropped items
-            nd_input_ids = qwen3_input_ids[non_drop_indices]
-            nd_attn_mask = qwen3_attn_mask[non_drop_indices]
+            # Only encode non-dropped items to save compute
+            nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
+            nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
         elif len(non_drop_indices) == batch_size:
-            nd_input_ids = qwen3_input_ids
-            nd_attn_mask = qwen3_attn_mask
+            nd_input_ids = qwen3_input_ids.to(encoder_device)
+            nd_attn_mask = qwen3_attn_mask.to(encoder_device)
         else:
             nd_input_ids = None
             nd_attn_mask = None
 
         if nd_input_ids is not None:
-            nd_input_ids = nd_input_ids.to(qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device)
-            nd_attn_mask = nd_attn_mask.to(nd_input_ids.device)
-
             outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
             nd_encoded_text = outputs.last_hidden_state
             # Zero out padding positions
            nd_encoded_text[~nd_attn_mask.bool()] = 0
 
-        # Fill back dropped items
+        # Build the full batch: non-dropped items get encoded text, dropped items get unconditional embeddings
        if len(non_drop_indices) == batch_size:
             prompt_embeds = nd_encoded_text
-            attn_mask = qwen3_attn_mask
+            attn_mask = qwen3_attn_mask.to(encoder_device)
         else:
+            # Get unconditional embeddings
+            if self._uncond_prompt_embeds is not None:
+                uncond_pe = self._uncond_prompt_embeds[0]
+                uncond_am = self._uncond_attn_mask[0]
+                uncond_t5_ids = self._uncond_t5_input_ids[0]
+                uncond_t5_am = self._uncond_t5_attn_mask[0]
+            else:
+                # Encode the empty caption on the fly (the text encoder is still available here)
+                uncond_tokens = tokenize_strategy.tokenize("")
+                uncond_ids = uncond_tokens[0].to(encoder_device)
+                uncond_mask = uncond_tokens[1].to(encoder_device)
+                uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
+                uncond_pe = uncond_out.last_hidden_state[0]
+                uncond_pe[~uncond_mask[0].bool()] = 0
+                uncond_am = uncond_mask[0]
+                uncond_t5_ids = uncond_tokens[2][0]
+                uncond_t5_am = uncond_tokens[3][0]
+
             seq_len = qwen3_input_ids.shape[1]
-            hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else 1024
-            placeholder_dtype = nd_encoded_text.dtype if nd_encoded_text is not None else torch.float32
-            device = nd_input_ids.device if nd_input_ids is not None else qwen3_input_ids.device
-            prompt_embeds = torch.zeros(
-                (batch_size, seq_len, hidden_size), device=device, dtype=placeholder_dtype
-            )
-            attn_mask = torch.zeros(
-                (batch_size, seq_len), device=device, dtype=qwen3_attn_mask.dtype
-            )
+            hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
+            dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
+
+            prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
+            attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
+
             if len(non_drop_indices) > 0:
                 prompt_embeds[non_drop_indices] = nd_encoded_text
                 attn_mask[non_drop_indices] = nd_attn_mask
 
-            # Zero out t5_input_ids and t5_attn_mask for dropped items
-            # so the LLM adapter sees a consistent unconditional signal
+            # Fill dropped items with the unconditional embeddings
             t5_input_ids = t5_input_ids.clone()
             t5_attn_mask = t5_attn_mask.clone()
             drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
             for i in drop_indices:
-                t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
-                t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
+                prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
+                attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
+                t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
+                t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
 
         return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
@@ -178,7 +221,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         """Apply dropout to cached text encoder outputs.
 
         Called during training when using cached outputs.
-        Clones tensors to avoid corrupting cached data.
+        Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
+        to match diffusion-pipe-main behavior.
         """
         if prompt_embeds is not None and self.dropout_rate > 0.0:
             # Clone to avoid in-place modification of cached tensors
@@ -192,13 +236,25 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
 
             for i in range(prompt_embeds.shape[0]):
                 if random.random() < self.dropout_rate:
-                    prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
-                    if attn_mask is not None:
-                        attn_mask[i] = torch.zeros_like(attn_mask[i])
-                    if t5_input_ids is not None:
-                        t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
-                    if t5_attn_mask is not None:
-                        t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
+                    if self._uncond_prompt_embeds is not None:
+                        # Use the pre-cached unconditional embeddings
+                        prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+                        if attn_mask is not None:
+                            attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
+                        if t5_input_ids is not None:
+                            t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
+                        if t5_attn_mask is not None:
+                            t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
+                    else:
+                        # Fallback: zero out (should not happen if cache_uncond_embeddings was called)
+                        logger.warning("Unconditional embeddings not cached; falling back to zeros for caption dropout")
+                        prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
+                        if attn_mask is not None:
+                            attn_mask[i] = torch.zeros_like(attn_mask[i])
+                        if t5_input_ids is not None:
+                            t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
+                        if t5_attn_mask is not None:
+                            t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
 
         return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
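Usage note: a minimal sketch of the intended call order, assuming the Anima tokenize/text-encoding strategies have already been registered via `strategy_base` (as in the `anima_train_network.py` hunk above) and that `accelerator` and `qwen3_text_encoder` are in scope; the `from library import strategy_base` import path is an assumption based on the repo layout, not shown in the patch.

```python
# Sketch only: shows when cache_uncond_embeddings() must run relative to
# freeing the text encoder. All names are taken from the diff above except
# the import path, which is assumed.
from library import strategy_base  # assumed import path

tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

# 1) While the Qwen3 text encoder is still on GPU, encode the empty caption ""
#    once; the four outputs (prompt embeds, attention mask, T5 input ids,
#    T5 attention mask) are stored on CPU inside the strategy.
with accelerator.autocast():
    text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])

# 2) The text encoder may now be deleted or offloaded. From here on, both the
#    live encode_tokens() path and the cached-outputs dropout path substitute
#    the cached unconditional embeddings for dropped captions instead of
#    zeroing them out.
```

Swapping in real empty-caption embeddings rather than zeros keeps the training-time unconditional signal identical to the one classifier-free guidance uses at sampling time, which is the diffusion-pipe-main behavior this patch mirrors.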