mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
feat: simplify encode_tokens
This commit is contained in:
@@ -45,8 +45,8 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
|
||||
|
||||
self.qwen3_tokenizer = qwen3_tokenizer
|
||||
self.t5_tokenizer = t5_tokenizer
|
||||
self.qwen3_max_length = qwen3_max_length
|
||||
self.t5_tokenizer = t5_tokenizer
|
||||
self.t5_max_length = t5_max_length
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
@@ -140,77 +140,70 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
|
||||
batch_size = qwen3_input_ids.shape[0]
|
||||
non_drop_indices = []
|
||||
for i in range(batch_size):
|
||||
drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
|
||||
if not drop:
|
||||
non_drop_indices.append(i)
|
||||
|
||||
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
|
||||
encoder_device = qwen3_text_encoder.device
|
||||
|
||||
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
|
||||
# Only encode non-dropped items to save compute
|
||||
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
|
||||
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
|
||||
elif len(non_drop_indices) == batch_size:
|
||||
nd_input_ids = qwen3_input_ids.to(encoder_device)
|
||||
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
|
||||
# Build drop mask
|
||||
if enable_dropout and self.dropout_rate > 0.0:
|
||||
drop_mask = torch.rand(batch_size) < self.dropout_rate
|
||||
else:
|
||||
nd_input_ids = None
|
||||
nd_attn_mask = None
|
||||
drop_mask = torch.zeros(batch_size, dtype=torch.bool)
|
||||
keep_mask = ~drop_mask
|
||||
|
||||
if nd_input_ids is not None:
|
||||
# Encode only kept items
|
||||
if keep_mask.any():
|
||||
nd_input_ids = qwen3_input_ids[keep_mask].to(encoder_device)
|
||||
nd_attn_mask = qwen3_attn_mask[keep_mask].to(encoder_device)
|
||||
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
|
||||
nd_encoded_text = outputs.last_hidden_state
|
||||
# Zero out padding positions
|
||||
nd_encoded_text[~nd_attn_mask.bool()] = 0
|
||||
else:
|
||||
nd_encoded_text = None
|
||||
|
||||
# Build full batch: fill non-dropped with encoded, dropped with unconditional
|
||||
if len(non_drop_indices) == batch_size:
|
||||
# If no items are dropped, return directly
|
||||
if not drop_mask.any():
|
||||
prompt_embeds = nd_encoded_text
|
||||
attn_mask = qwen3_attn_mask.to(encoder_device)
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
# Get unconditional embeddings
|
||||
if self._uncond_prompt_embeds is not None:
|
||||
uncond_pe = self._uncond_prompt_embeds[0]
|
||||
uncond_am = self._uncond_attn_mask[0]
|
||||
uncond_t5_ids = self._uncond_t5_input_ids[0]
|
||||
uncond_t5_am = self._uncond_t5_attn_mask[0]
|
||||
else:
|
||||
# Get unconditional embeddings
|
||||
if self._uncond_prompt_embeds is not None:
|
||||
uncond_pe = self._uncond_prompt_embeds[0]
|
||||
uncond_am = self._uncond_attn_mask[0]
|
||||
uncond_t5_ids = self._uncond_t5_input_ids[0]
|
||||
uncond_t5_am = self._uncond_t5_attn_mask[0]
|
||||
else:
|
||||
# Encode empty caption on-the-fly (text encoder still available)
|
||||
uncond_tokens = tokenize_strategy.tokenize("")
|
||||
uncond_ids = uncond_tokens[0].to(encoder_device)
|
||||
uncond_mask = uncond_tokens[1].to(encoder_device)
|
||||
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
|
||||
uncond_pe = uncond_out.last_hidden_state[0]
|
||||
uncond_pe[~uncond_mask[0].bool()] = 0
|
||||
uncond_am = uncond_mask[0]
|
||||
uncond_t5_ids = uncond_tokens[2][0]
|
||||
uncond_t5_am = uncond_tokens[3][0]
|
||||
# Recursively encode empty caption with no dropout
|
||||
uncond_tokens = tokenize_strategy.tokenize("")
|
||||
with torch.no_grad():
|
||||
uncond_pe, uncond_am, uncond_t5_ids, uncond_t5_am = self.encode_tokens(
|
||||
tokenize_strategy, models, uncond_tokens, enable_dropout=False
|
||||
)
|
||||
|
||||
seq_len = qwen3_input_ids.shape[1]
|
||||
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
|
||||
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
|
||||
seq_len = qwen3_input_ids.shape[1]
|
||||
hidden_size = (nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1])
|
||||
dtype = (nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype)
|
||||
|
||||
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
|
||||
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
|
||||
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
|
||||
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
|
||||
|
||||
if len(non_drop_indices) > 0:
|
||||
prompt_embeds[non_drop_indices] = nd_encoded_text
|
||||
attn_mask[non_drop_indices] = nd_attn_mask
|
||||
if keep_mask.any():
|
||||
prompt_embeds[keep_mask] = nd_encoded_text
|
||||
attn_mask[keep_mask] = nd_attn_mask
|
||||
|
||||
# Fill dropped items with unconditional embeddings
|
||||
t5_input_ids = t5_input_ids.clone()
|
||||
t5_attn_mask = t5_attn_mask.clone()
|
||||
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
|
||||
for i in drop_indices:
|
||||
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
|
||||
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
|
||||
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
|
||||
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
|
||||
# Fill dropped items
|
||||
prompt_embeds[drop_mask] = uncond_pe.to(device=encoder_device, dtype=dtype)
|
||||
attn_mask[drop_mask] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
|
||||
|
||||
t5_input_ids = t5_input_ids.clone()
|
||||
t5_attn_mask = t5_attn_mask.clone()
|
||||
t5_input_ids[drop_mask] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
|
||||
t5_attn_mask[drop_mask] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
|
||||
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
|
||||
def drop_cached_text_encoder_outputs(
|
||||
self,
|
||||
prompt_embeds: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user