feat: kept caption dropout rate in cache and handle in training script

This commit is contained in:
kohya-ss
2026-02-08 15:35:53 +09:00
parent c3556d455f
commit 4f6511bf28
4 changed files with 74 additions and 134 deletions

View File

@@ -255,10 +255,8 @@ def train(args):
)
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
if caption_dropout_rate > 0.0:
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
accelerator.wait_for_everyone()

View File

@@ -231,12 +231,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
self.sample_prompts_te_outputs = sample_prompts_te_outputs
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
if caption_dropout_rate > 0.0:
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
with accelerator.autocast():
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
with accelerator.autocast():
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
accelerator.wait_for_everyone()
@@ -433,13 +431,21 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
# Text encoder conditions
text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
# Apply caption dropout to cached outputs
text_encoder_conds = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
# TODO stop gradient for uncond embeddings when using caption dropout?
encoded_text_encoder_conds = anima_text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
@@ -450,6 +456,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# Fill in only missing parts (partial caching)
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]

View File

@@ -54,22 +54,14 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
# Tokenize with Qwen3
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.qwen3_max_length,
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
)
qwen3_input_ids = qwen3_encoding["input_ids"]
qwen3_attn_mask = qwen3_encoding["attention_mask"]
# Tokenize with T5 (for LLM Adapter target tokens)
t5_encoding = self.t5_tokenizer.batch_encode_plus(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.t5_max_length,
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
)
t5_input_ids = t5_encoding["input_ids"]
t5_attn_mask = t5_encoding["attention_mask"]
@@ -84,11 +76,9 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
T5 tokens are passed through unchanged (only used by LLM Adapter).
"""
def __init__(
self,
dropout_rate: float = 0.0,
) -> None:
self.dropout_rate = dropout_rate
def __init__(self) -> None:
super().__init__()
# Cached unconditional embeddings (from encoding empty caption "")
# Must be initialized via cache_uncond_embeddings() before text encoder is deleted
self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
@@ -96,11 +86,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
def cache_uncond_embeddings(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
) -> None:
def cache_uncond_embeddings(self, tokenize_strategy: TokenizeStrategy, models: List[Any]) -> None:
"""Pre-encode empty caption "" and cache the unconditional embeddings.
Must be called before the text encoder is deleted from GPU.
@@ -110,7 +96,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
tokens = tokenize_strategy.tokenize("")
with torch.no_grad():
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens)
# Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
self._uncond_prompt_embeds = uncond_outputs[0].cpu()
self._uncond_attn_mask = uncond_outputs[1].cpu()
@@ -119,11 +105,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
logger.info(" Unconditional embeddings cached successfully")
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
enable_dropout: bool = True,
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -134,74 +116,19 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
batch_size = qwen3_input_ids.shape[0]
encoder_device = qwen3_text_encoder.device
# Build drop mask
if enable_dropout and self.dropout_rate > 0.0:
drop_mask = torch.rand(batch_size) < self.dropout_rate
else:
drop_mask = torch.zeros(batch_size, dtype=torch.bool)
keep_mask = ~drop_mask
qwen3_input_ids = qwen3_input_ids.to(encoder_device)
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
prompt_embeds = outputs.last_hidden_state
# Encode only kept items
if keep_mask.any():
nd_input_ids = qwen3_input_ids[keep_mask].to(encoder_device)
nd_attn_mask = qwen3_attn_mask[keep_mask].to(encoder_device)
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
else:
nd_encoded_text = None
# If no items are dropped, return directly
if not drop_mask.any():
prompt_embeds = nd_encoded_text
attn_mask = qwen3_attn_mask.to(encoder_device)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
# Get unconditional embeddings
if self._uncond_prompt_embeds is not None:
uncond_pe = self._uncond_prompt_embeds[0]
uncond_am = self._uncond_attn_mask[0]
uncond_t5_ids = self._uncond_t5_input_ids[0]
uncond_t5_am = self._uncond_t5_attn_mask[0]
else:
# Recursively encode empty caption with no dropout
uncond_tokens = tokenize_strategy.tokenize("")
with torch.no_grad():
uncond_pe, uncond_am, uncond_t5_ids, uncond_t5_am = self.encode_tokens(
tokenize_strategy, models, uncond_tokens, enable_dropout=False
)
seq_len = qwen3_input_ids.shape[1]
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
if keep_mask.any():
prompt_embeds[keep_mask] = nd_encoded_text
attn_mask[keep_mask] = nd_attn_mask
# Fill dropped items
prompt_embeds[drop_mask] = uncond_pe.to(device=encoder_device, dtype=dtype)
attn_mask[drop_mask] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
t5_input_ids = t5_input_ids.clone()
t5_attn_mask = t5_attn_mask.clone()
t5_input_ids[drop_mask] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
t5_attn_mask[drop_mask] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
@@ -209,6 +136,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
attn_mask: torch.Tensor,
t5_input_ids: torch.Tensor,
t5_attn_mask: torch.Tensor,
caption_dropout_rates: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""Apply dropout to cached text encoder outputs.
@@ -216,37 +144,30 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior.
"""
if prompt_embeds is not None and self.dropout_rate > 0.0:
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
if caption_dropout_rates is None or all(caption_dropout_rates == 0.0):
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
for i in range(prompt_embeds.shape[0]):
if random.random() < self.dropout_rate:
if self._uncond_prompt_embeds is not None:
# Use pre-cached unconditional embeddings
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if attn_mask is not None:
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
if t5_input_ids is not None:
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
if t5_attn_mask is not None:
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
else:
# Fallback: zero out (should not happen if cache_uncond_embeddings was called)
logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
if attn_mask is not None:
attn_mask[i] = torch.zeros_like(attn_mask[i])
if t5_input_ids is not None:
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
assert self._uncond_prompt_embeds is not None, "Unconditional embeddings not cached, cannot apply caption dropout"
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
for i in range(prompt_embeds.shape[0]):
if random.random() < caption_dropout_rates[i].item():
# Use pre-cached unconditional embeddings
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if attn_mask is not None:
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
if t5_input_ids is not None:
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
if t5_attn_mask is not None:
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
@@ -289,6 +210,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -301,7 +224,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask = data["attn_mask"]
t5_input_ids = data["t5_input_ids"]
t5_attn_mask = data["t5_attn_mask"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
caption_dropout_rate = data["caption_dropout_rate"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
def cache_batch_outputs(
self,
@@ -336,6 +260,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
@@ -344,9 +269,16 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i)
info.text_encoder_outputs = (
prompt_embeds_i,
attn_mask_i,
t5_input_ids_i,
t5_attn_mask_i,
caption_dropout_rate,
)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):

View File

@@ -179,12 +179,15 @@ def split_train_val(
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
def __init__(
self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.caption_dropout_rate: float = caption_dropout_rate
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
@@ -197,7 +200,7 @@ class ImageInfo:
)
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
@@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
if caption is None:
caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
image_info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)