Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-16 17:02:45 +00:00)

feat: kept caption dropout rate in cache and handle in training script

@@ -255,10 +255,8 @@ def train(args):
     )

     # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
-    caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
-    if caption_dropout_rate > 0.0:
-        with accelerator.autocast():
-            text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
+    with accelerator.autocast():
+        text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])

     accelerator.wait_for_everyone()

@@ -231,12 +231,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
             self.sample_prompts_te_outputs = sample_prompts_te_outputs

         # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
-        caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
         text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
-        if caption_dropout_rate > 0.0:
-            tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
-            with accelerator.autocast():
-                text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
+        tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
+        with accelerator.autocast():
+            text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)

         accelerator.wait_for_everyone()

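Note: with the dropout rate now carried per item in the cache, both call sites above always pre-cache the unconditional embeddings instead of gating on a global args.caption_dropout_rate. A minimal, self-contained sketch of that idea (illustrative names, not this repository's API): encode the empty caption once while the text encoder is still loaded, keep the result on CPU, and reuse it whenever a caption is later dropped.

import torch
from typing import Optional


class UncondEmbeddingCache:
    """Holds the embedding of the empty caption so it can outlive the text encoder."""

    def __init__(self) -> None:
        self.prompt_embeds: Optional[torch.Tensor] = None  # (1, seq_len, hidden), kept on CPU
        self.attn_mask: Optional[torch.Tensor] = None      # (1, seq_len), kept on CPU

    @torch.no_grad()
    def cache(self, tokenizer, text_encoder, max_length: int = 77) -> None:
        # Assumes a Hugging Face style tokenizer/encoder pair; both names are hypothetical here.
        enc = tokenizer("", return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
        device = next(text_encoder.parameters()).device
        out = text_encoder(input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device))
        # Store on CPU; the encoder itself can now be deleted or offloaded.
        self.prompt_embeds = out.last_hidden_state.cpu()
        self.attn_mask = enc["attention_mask"].cpu()
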
@@ -433,13 +431,21 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
             # Text encoder conditions
             text_encoder_conds = []
             text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+            anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
             if text_encoder_outputs_list is not None:
-                text_encoder_conds = text_encoder_outputs_list
+                caption_dropout_rates = text_encoder_outputs_list[-1]
+                text_encoder_outputs_list = text_encoder_outputs_list[:-1]
+
+                # Apply caption dropout to cached outputs
+                text_encoder_conds = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
+                    *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
+                )

             if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                 with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
                     input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
-                    encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
+                    # TODO stop gradient for uncond embeddings when using caption dropout?
+                    encoded_text_encoder_conds = anima_text_encoding_strategy.encode_tokens(
                         tokenize_strategy,
                         self.get_models_for_text_encoding(args, accelerator, text_encoders),
                         input_ids,

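The cached entry now ends with the per-item dropout rate, which the training step splits off before handing the remaining outputs to drop_cached_text_encoder_outputs. A small standalone illustration with dummy tensors (shapes are invented; the strategy call is left as a comment because no strategy object exists in this sketch):

import torch

cached_entry = [
    torch.randn(2, 77, 1024),              # prompt_embeds
    torch.ones(2, 77, dtype=torch.long),   # attn_mask
    torch.ones(2, 226, dtype=torch.long),  # t5_input_ids
    torch.ones(2, 226, dtype=torch.long),  # t5_attn_mask
    torch.tensor([0.1, 0.0]),              # caption_dropout_rate, one value per item
]

caption_dropout_rates = cached_entry[-1]   # last element: per-item rates
text_encoder_outputs = cached_entry[:-1]   # everything else: the usual outputs
# text_encoder_conds = strategy.drop_cached_text_encoder_outputs(
#     *text_encoder_outputs, caption_dropout_rates=caption_dropout_rates
# )
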
@@ -450,6 +456,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
                 if len(text_encoder_conds) == 0:
                     text_encoder_conds = encoded_text_encoder_conds
                 else:
+                    # Fill in only missing parts (partial caching)
                     for i in range(len(encoded_text_encoder_conds)):
                         if encoded_text_encoder_conds[i] is not None:
                             text_encoder_conds[i] = encoded_text_encoder_conds[i]

@@ -54,22 +54,14 @@ class AnimaTokenizeStrategy(TokenizeStrategy):

         # Tokenize with Qwen3
         qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
-            text,
-            return_tensors="pt",
-            truncation=True,
-            padding="max_length",
-            max_length=self.qwen3_max_length,
+            text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
         )
         qwen3_input_ids = qwen3_encoding["input_ids"]
         qwen3_attn_mask = qwen3_encoding["attention_mask"]

         # Tokenize with T5 (for LLM Adapter target tokens)
         t5_encoding = self.t5_tokenizer.batch_encode_plus(
-            text,
-            return_tensors="pt",
-            truncation=True,
-            padding="max_length",
-            max_length=self.t5_max_length,
+            text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
         )
         t5_input_ids = t5_encoding["input_ids"]
         t5_attn_mask = t5_encoding["attention_mask"]

@@ -84,11 +76,9 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
     T5 tokens are passed through unchanged (only used by LLM Adapter).
     """

-    def __init__(
-        self,
-        dropout_rate: float = 0.0,
-    ) -> None:
-        self.dropout_rate = dropout_rate
+    def __init__(self) -> None:
+        super().__init__()

         # Cached unconditional embeddings (from encoding empty caption "")
         # Must be initialized via cache_uncond_embeddings() before text encoder is deleted
         self._uncond_prompt_embeds: Optional[torch.Tensor] = None  # (1, seq_len, hidden)

@@ -96,11 +86,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         self._uncond_t5_input_ids: Optional[torch.Tensor] = None  # (1, t5_seq_len)
         self._uncond_t5_attn_mask: Optional[torch.Tensor] = None  # (1, t5_seq_len)

-    def cache_uncond_embeddings(
-        self,
-        tokenize_strategy: TokenizeStrategy,
-        models: List[Any],
-    ) -> None:
+    def cache_uncond_embeddings(self, tokenize_strategy: TokenizeStrategy, models: List[Any]) -> None:
         """Pre-encode empty caption "" and cache the unconditional embeddings.

         Must be called before the text encoder is deleted from GPU.

@@ -110,7 +96,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
         tokens = tokenize_strategy.tokenize("")
         with torch.no_grad():
-            uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
+            uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens)
         # Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
         self._uncond_prompt_embeds = uncond_outputs[0].cpu()
         self._uncond_attn_mask = uncond_outputs[1].cpu()

@@ -119,11 +105,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         logger.info(" Unconditional embeddings cached successfully")

     def encode_tokens(
-        self,
-        tokenize_strategy: TokenizeStrategy,
-        models: List[Any],
-        tokens: List[torch.Tensor],
-        enable_dropout: bool = True,
+        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
     ) -> List[torch.Tensor]:
         """Encode Qwen3 tokens and return embeddings + T5 token IDs.

@@ -134,74 +116,19 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         Returns:
             [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
         """
+        # Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()

         qwen3_text_encoder = models[0]
         qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens

-        # Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
-        batch_size = qwen3_input_ids.shape[0]
-
         encoder_device = qwen3_text_encoder.device
-
-        # Build drop mask
-        if enable_dropout and self.dropout_rate > 0.0:
-            drop_mask = torch.rand(batch_size) < self.dropout_rate
-        else:
-            drop_mask = torch.zeros(batch_size, dtype=torch.bool)
-        keep_mask = ~drop_mask
+        qwen3_input_ids = qwen3_input_ids.to(encoder_device)
+        qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
+        outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
+        prompt_embeds = outputs.last_hidden_state

-        # Encode only kept items
-        if keep_mask.any():
-            nd_input_ids = qwen3_input_ids[keep_mask].to(encoder_device)
-            nd_attn_mask = qwen3_attn_mask[keep_mask].to(encoder_device)
-            outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
-            nd_encoded_text = outputs.last_hidden_state
-            # Zero out padding positions
-            nd_encoded_text[~nd_attn_mask.bool()] = 0
-        else:
-            nd_encoded_text = None
-
-        # If no items are dropped, return directly
-        if not drop_mask.any():
-            prompt_embeds = nd_encoded_text
-            attn_mask = qwen3_attn_mask.to(encoder_device)
-            return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
-
-        # Get unconditional embeddings
-        if self._uncond_prompt_embeds is not None:
-            uncond_pe = self._uncond_prompt_embeds[0]
-            uncond_am = self._uncond_attn_mask[0]
-            uncond_t5_ids = self._uncond_t5_input_ids[0]
-            uncond_t5_am = self._uncond_t5_attn_mask[0]
-        else:
-            # Recursively encode empty caption with no dropout
-            uncond_tokens = tokenize_strategy.tokenize("")
-            with torch.no_grad():
-                uncond_pe, uncond_am, uncond_t5_ids, uncond_t5_am = self.encode_tokens(
-                    tokenize_strategy, models, uncond_tokens, enable_dropout=False
-                )
-
-        seq_len = qwen3_input_ids.shape[1]
-        hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
-        dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
-
-        prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
-        attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
-
-        if keep_mask.any():
-            prompt_embeds[keep_mask] = nd_encoded_text
-            attn_mask[keep_mask] = nd_attn_mask
-
-        # Fill dropped items
-        prompt_embeds[drop_mask] = uncond_pe.to(device=encoder_device, dtype=dtype)
-        attn_mask[drop_mask] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
-
-        t5_input_ids = t5_input_ids.clone()
-        t5_attn_mask = t5_attn_mask.clone()
-        t5_input_ids[drop_mask] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
-        t5_attn_mask[drop_mask] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
-
-        return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
+        return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]

     def drop_cached_text_encoder_outputs(
         self,

@@ -209,6 +136,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         attn_mask: torch.Tensor,
         t5_input_ids: torch.Tensor,
         t5_attn_mask: torch.Tensor,
+        caption_dropout_rates: Optional[torch.Tensor] = None,
     ) -> List[torch.Tensor]:
         """Apply dropout to cached text encoder outputs.

@@ -216,37 +144,30 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
         to match diffusion-pipe-main behavior.
         """
-        if prompt_embeds is not None and self.dropout_rate > 0.0:
-            # Clone to avoid in-place modification of cached tensors
-            prompt_embeds = prompt_embeds.clone()
-            if attn_mask is not None:
-                attn_mask = attn_mask.clone()
-            if t5_input_ids is not None:
-                t5_input_ids = t5_input_ids.clone()
-            if t5_attn_mask is not None:
-                t5_attn_mask = t5_attn_mask.clone()
+        if caption_dropout_rates is None or all(caption_dropout_rates == 0.0):
+            return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]

-            for i in range(prompt_embeds.shape[0]):
-                if random.random() < self.dropout_rate:
-                    if self._uncond_prompt_embeds is not None:
-                        # Use pre-cached unconditional embeddings
-                        prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
-                        if attn_mask is not None:
-                            attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
-                        if t5_input_ids is not None:
-                            t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
-                        if t5_attn_mask is not None:
-                            t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
-                    else:
-                        # Fallback: zero out (should not happen if cache_uncond_embeddings was called)
-                        logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
-                        prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
-                        if attn_mask is not None:
-                            attn_mask[i] = torch.zeros_like(attn_mask[i])
-                        if t5_input_ids is not None:
-                            t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
-                        if t5_attn_mask is not None:
-                            t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
+        assert self._uncond_prompt_embeds is not None, "Unconditional embeddings not cached, cannot apply caption dropout"
+
+        # Clone to avoid in-place modification of cached tensors
+        prompt_embeds = prompt_embeds.clone()
+        if attn_mask is not None:
+            attn_mask = attn_mask.clone()
+        if t5_input_ids is not None:
+            t5_input_ids = t5_input_ids.clone()
+        if t5_attn_mask is not None:
+            t5_attn_mask = t5_attn_mask.clone()
+
+        for i in range(prompt_embeds.shape[0]):
+            if random.random() < caption_dropout_rates[i].item():
+                # Use pre-cached unconditional embeddings
+                prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+                if attn_mask is not None:
+                    attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
+                if t5_input_ids is not None:
+                    t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
+                if t5_attn_mask is not None:
+                    t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)

         return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]

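The rewritten drop_cached_text_encoder_outputs draws an independent Bernoulli sample per item using that item's own rate, and substitutes the pre-cached "" embedding rather than zeros. A runnable sketch of just that mechanism (tensor shapes and the zero-filled stand-in for the cached embedding are invented for illustration):

import random
import torch


def drop_with_cached_uncond(prompt_embeds: torch.Tensor, rates: torch.Tensor, uncond_embed: torch.Tensor) -> torch.Tensor:
    """prompt_embeds: (B, L, D); rates: (B,); uncond_embed: (1, L, D) cached from the empty caption."""
    out = prompt_embeds.clone()  # avoid mutating cached tensors in place
    for i in range(out.shape[0]):
        if random.random() < rates[i].item():
            out[i] = uncond_embed[0].to(device=out.device, dtype=out.dtype)
    return out


embeds = torch.randn(4, 77, 1024)
rates = torch.tensor([0.0, 0.1, 0.5, 1.0])   # per-item dropout rates, e.g. from different subsets
uncond = torch.zeros(1, 77, 1024)            # stand-in for the cached "" embedding
dropped = drop_with_cached_uncond(embeds, rates, uncond)
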
@@ -289,6 +210,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                 return False
             if "t5_attn_mask" not in npz:
                 return False
+            if "caption_dropout_rate" not in npz:
+                return False
         except Exception as e:
             logger.error(f"Error loading file: {npz_path}")
             raise e

@@ -301,7 +224,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         attn_mask = data["attn_mask"]
         t5_input_ids = data["t5_input_ids"]
         t5_attn_mask = data["t5_attn_mask"]
-        return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
+        caption_dropout_rate = data["caption_dropout_rate"]
+        return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]

     def cache_batch_outputs(
         self,

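For reference, the extended on-disk format simply adds one more key to the npz file and reads it back as a fifth element, so caches written before this change fail the key check above and are regenerated. A hypothetical round-trip (file name and shapes invented for illustration):

import numpy as np

np.savez(
    "sample_te_outputs.npz",
    prompt_embeds=np.zeros((77, 1024), dtype=np.float32),
    attn_mask=np.ones((77,), dtype=np.int64),
    t5_input_ids=np.ones((226,), dtype=np.int64),
    t5_attn_mask=np.ones((226,), dtype=np.int64),
    caption_dropout_rate=np.float32(0.1),
)

data = np.load("sample_te_outputs.npz")
assert "caption_dropout_rate" in data  # older caches without this key are treated as stale
outputs = [data["prompt_embeds"], data["attn_mask"], data["t5_input_ids"], data["t5_attn_mask"], data["caption_dropout_rate"]]
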
@@ -336,6 +260,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
             attn_mask_i = attn_mask[i]
             t5_input_ids_i = t5_input_ids[i]
             t5_attn_mask_i = t5_attn_mask[i]
+            caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)

             if self.cache_to_disk:
                 np.savez(

@@ -344,9 +269,16 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                     attn_mask=attn_mask_i,
                     t5_input_ids=t5_input_ids_i,
                     t5_attn_mask=t5_attn_mask_i,
+                    caption_dropout_rate=caption_dropout_rate,
                 )
             else:
-                info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i)
+                info.text_encoder_outputs = (
+                    prompt_embeds_i,
+                    attn_mask_i,
+                    t5_input_ids_i,
+                    t5_attn_mask_i,
+                    caption_dropout_rate,
+                )


 class AnimaLatentsCachingStrategy(LatentsCachingStrategy):

@@ -179,12 +179,15 @@ def split_train_val(


 class ImageInfo:
-    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
+    def __init__(
+        self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
+    ) -> None:
         self.image_key: str = image_key
         self.num_repeats: int = num_repeats
         self.caption: str = caption
         self.is_reg: bool = is_reg
         self.absolute_path: str = absolute_path
+        self.caption_dropout_rate: float = caption_dropout_rate
         self.image_size: Tuple[int, int] = None
         self.resized_size: Tuple[int, int] = None
         self.bucket_reso: Tuple[int, int] = None

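ImageInfo now carries the subset's dropout rate (defaulting to 0.0) so the caching strategy can write it into each cache entry. A hedged dataclass sketch of the same shape (a stand-in for illustration, not the real train_util.ImageInfo):

from dataclasses import dataclass


@dataclass
class ImageInfoSketch:
    image_key: str
    num_repeats: int
    caption: str
    is_reg: bool
    absolute_path: str
    caption_dropout_rate: float = 0.0  # per-subset rate, propagated from the dataset config


info = ImageInfoSketch("img_0001", 1, "a photo of a cat", False, "/data/img_0001.png", caption_dropout_rate=0.05)
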
@@ -197,7 +200,7 @@ class ImageInfo:
         )
         self.cond_img_path: Optional[str] = None
         self.image: Optional[Image.Image] = None  # optional, original PIL Image
-        self.text_encoder_outputs_npz: Optional[str] = None  # set in cache_text_encoder_outputs
+        self.text_encoder_outputs_npz: Optional[str] = None  # filename. set in cache_text_encoder_outputs

         # new
         self.text_encoder_outputs: Optional[List[torch.Tensor]] = None

@@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
             num_train_images += num_repeats * len(img_paths)

             for img_path, caption, size in zip(img_paths, captions, sizes):
-                info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
+                info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
                 info.resize_interpolation = (
                     subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
                 )

@@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
             if caption is None:
                 caption = ""

-            image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
+            image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
             image_info.resize_interpolation = (
                 subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
             )