feat: simplify encode_tokens

This commit is contained in:
kohya-ss
2026-02-08 12:06:12 +09:00
parent 7b0ed3269a
commit 10445ff660

View File

@@ -45,8 +45,8 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
self.qwen3_tokenizer = qwen3_tokenizer
self.t5_tokenizer = t5_tokenizer
self.qwen3_max_length = qwen3_max_length
self.t5_tokenizer = t5_tokenizer
self.t5_max_length = t5_max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
@@ -140,77 +140,70 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
batch_size = qwen3_input_ids.shape[0]
non_drop_indices = []
for i in range(batch_size):
drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
if not drop:
non_drop_indices.append(i)
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
encoder_device = qwen3_text_encoder.device
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
# Only encode non-dropped items to save compute
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
elif len(non_drop_indices) == batch_size:
nd_input_ids = qwen3_input_ids.to(encoder_device)
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
# Build drop mask
if enable_dropout and self.dropout_rate > 0.0:
drop_mask = torch.rand(batch_size) < self.dropout_rate
else:
nd_input_ids = None
nd_attn_mask = None
drop_mask = torch.zeros(batch_size, dtype=torch.bool)
keep_mask = ~drop_mask
if nd_input_ids is not None:
# Encode only kept items
if keep_mask.any():
nd_input_ids = qwen3_input_ids[keep_mask].to(encoder_device)
nd_attn_mask = qwen3_attn_mask[keep_mask].to(encoder_device)
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
else:
nd_encoded_text = None
# Build full batch: fill non-dropped with encoded, dropped with unconditional
if len(non_drop_indices) == batch_size:
# If no items are dropped, return directly
if not drop_mask.any():
prompt_embeds = nd_encoded_text
attn_mask = qwen3_attn_mask.to(encoder_device)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
# Get unconditional embeddings
if self._uncond_prompt_embeds is not None:
uncond_pe = self._uncond_prompt_embeds[0]
uncond_am = self._uncond_attn_mask[0]
uncond_t5_ids = self._uncond_t5_input_ids[0]
uncond_t5_am = self._uncond_t5_attn_mask[0]
else:
# Get unconditional embeddings
if self._uncond_prompt_embeds is not None:
uncond_pe = self._uncond_prompt_embeds[0]
uncond_am = self._uncond_attn_mask[0]
uncond_t5_ids = self._uncond_t5_input_ids[0]
uncond_t5_am = self._uncond_t5_attn_mask[0]
else:
# Encode empty caption on-the-fly (text encoder still available)
uncond_tokens = tokenize_strategy.tokenize("")
uncond_ids = uncond_tokens[0].to(encoder_device)
uncond_mask = uncond_tokens[1].to(encoder_device)
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
uncond_pe = uncond_out.last_hidden_state[0]
uncond_pe[~uncond_mask[0].bool()] = 0
uncond_am = uncond_mask[0]
uncond_t5_ids = uncond_tokens[2][0]
uncond_t5_am = uncond_tokens[3][0]
# Recursively encode empty caption with no dropout
uncond_tokens = tokenize_strategy.tokenize("")
with torch.no_grad():
uncond_pe, uncond_am, uncond_t5_ids, uncond_t5_am = self.encode_tokens(
tokenize_strategy, models, uncond_tokens, enable_dropout=False
)
seq_len = qwen3_input_ids.shape[1]
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
seq_len = qwen3_input_ids.shape[1]
hidden_size = (nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1])
dtype = (nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype)
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
if len(non_drop_indices) > 0:
prompt_embeds[non_drop_indices] = nd_encoded_text
attn_mask[non_drop_indices] = nd_attn_mask
if keep_mask.any():
prompt_embeds[keep_mask] = nd_encoded_text
attn_mask[keep_mask] = nd_attn_mask
# Fill dropped items with unconditional embeddings
t5_input_ids = t5_input_ids.clone()
t5_attn_mask = t5_attn_mask.clone()
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
for i in drop_indices:
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
# Fill dropped items
prompt_embeds[drop_mask] = uncond_pe.to(device=encoder_device, dtype=dtype)
attn_mask[drop_mask] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
t5_input_ids = t5_input_ids.clone()
t5_attn_mask = t5_attn_mask.clone()
t5_input_ids[drop_mask] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
t5_attn_mask[drop_mask] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
prompt_embeds: torch.Tensor,