Add dropout rate arguments for CLIP-L, CLIP-G, and T5; fix Text Encoder LoRAs not being trained

This commit is contained in:
Kohya S
2024-10-27 16:42:58 +09:00
parent b649bbf2b6
commit db2b4d41b9
5 changed files with 138 additions and 17 deletions

View File

@@ -214,6 +214,24 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
action="store_true", action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスクゼロ埋めを適用する", help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスクゼロ埋めを適用する",
) )
parser.add_argument(
"--clip_l_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--clip_g_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--t5_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
)
# copy from Diffusers # copy from Diffusers
parser.add_argument( parser.add_argument(

View File

@@ -1,5 +1,6 @@
import os import os
import glob import glob
import random
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import torch import torch
import numpy as np import numpy as np
@@ -48,13 +49,23 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
class Sd3TextEncodingStrategy(TextEncodingStrategy): class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None: def __init__(
self,
apply_lg_attn_mask: Optional[bool] = None,
apply_t5_attn_mask: Optional[bool] = None,
l_dropout_rate: float = 0.0,
g_dropout_rate: float = 0.0,
t5_dropout_rate: float = 0.0,
) -> None:
""" """
Args: Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask. apply_t5_attn_mask: Default value for apply_t5_attn_mask.
""" """
self.apply_lg_attn_mask = apply_lg_attn_mask self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask self.apply_t5_attn_mask = apply_t5_attn_mask
self.l_dropout_rate = l_dropout_rate
self.g_dropout_rate = g_dropout_rate
self.t5_dropout_rate = t5_dropout_rate
def encode_tokens( def encode_tokens(
self, self,
@@ -63,6 +74,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
tokens: List[torch.Tensor], tokens: List[torch.Tensor],
apply_lg_attn_mask: Optional[bool] = False, apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False, apply_t5_attn_mask: Optional[bool] = False,
enable_dropout: bool = True,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
""" """
returned embeddings are not masked returned embeddings are not masked
@@ -91,37 +103,92 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
g_attn_mask = None g_attn_mask = None
t5_attn_mask = None t5_attn_mask = None
# dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
if l_tokens is None or clip_l is None: if l_tokens is None or clip_l is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None" assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None lg_out = None
lg_pooled = None lg_pooled = None
else: else:
with torch.no_grad(): assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
if drop_l:
l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype)
l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype)
if l_attn_mask is not None:
l_attn_mask = torch.zeros_like(l_attn_mask)
else:
l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
l_pooled = prompt_embeds[0] l_pooled = prompt_embeds[0]
l_out = prompt_embeds.hidden_states[-2] l_out = prompt_embeds.hidden_states[-2]
drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
if drop_g:
g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
if g_attn_mask is not None:
g_attn_mask = torch.zeros_like(g_attn_mask)
else:
g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
g_pooled = prompt_embeds[0] g_pooled = prompt_embeds[0]
g_out = prompt_embeds.hidden_states[-2] g_out = prompt_embeds.hidden_states[-2]
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
lg_out = torch.cat([l_out, g_out], dim=-1) lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is None or t5_tokens is None: if t5xxl is None or t5_tokens is None:
t5_out = None t5_out = None
else: else:
t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
with torch.no_grad(): if drop_t5:
t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype)
if t5_attn_mask is not None:
t5_attn_mask = torch.zeros_like(t5_attn_mask)
else:
t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
# masks are used for attention masking in transformer # masks are used for attention masking in transformer
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def drop_cached_text_encoder_outputs(
    self,
    lg_out: torch.Tensor,
    t5_out: torch.Tensor,
    lg_pooled: torch.Tensor,
    l_attn_mask: torch.Tensor,
    g_attn_mask: torch.Tensor,
    t5_attn_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply per-sample dropout to cached text encoder outputs.

    "Dropout" here means replacing a sample's embeddings (and the matching
    attention-mask row, when a mask is present) with zeros. The first 768
    channels of lg_out / lg_pooled are CLIP-L, the rest CLIP-G.

    NOTE(review): tensors are modified in place — assumes callers pass
    per-batch copies rather than the persistent cache; confirm at call sites.
    """

    def hit(rate: float) -> bool:
        # One RNG draw per (sample, encoder); short-circuits so no draw
        # happens when the rate is zero.
        return rate > 0.0 and random.random() < rate

    if lg_out is not None:
        for idx in range(lg_out.shape[0]):
            if hit(self.l_dropout_rate):
                lg_out[idx, :, :768] = 0.0
                lg_pooled[idx, :768] = 0.0
                if l_attn_mask is not None:
                    l_attn_mask[idx] = 0
            if hit(self.g_dropout_rate):
                lg_out[idx, :, 768:] = 0.0
                lg_pooled[idx, 768:] = 0.0
                if g_attn_mask is not None:
                    g_attn_mask[idx] = 0
    if t5_out is not None:
        for idx in range(t5_out.shape[0]):
            if hit(self.t5_dropout_rate):
                t5_out[idx] = 0.0
                if t5_attn_mask is not None:
                    t5_attn_mask[idx] = 0
    return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask
def concat_encodings( def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -207,8 +274,14 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions) tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad(): with torch.no_grad():
# always disable dropout during caching
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.apply_lg_attn_mask,
apply_t5_attn_mask=self.apply_t5_attn_mask,
enable_dropout=False,
) )
if lg_out.dtype == torch.bfloat16: if lg_out.dtype == torch.bfloat16:

View File

@@ -69,6 +69,11 @@ def train(args):
# assert ( # assert (
# not args.train_text_encoder or not args.cache_text_encoder_outputs # not args.train_text_encoder or not args.cache_text_encoder_outputs
# ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
"when training text encoder, text encoder outputs must not be cached (except for T5XXL)" "when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
@@ -232,7 +237,9 @@ def train(args):
assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified"
# prepare text encoding strategy # prepare text encoding strategy
text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(
args.apply_lg_attn_mask, args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, args.t5_dropout_rate
)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# 学習を準備する:モデルを適切な状態にする # 学習を準備する:モデルを適切な状態にする
@@ -311,6 +318,7 @@ def train(args):
tokens_and_masks, tokens_and_masks,
args.apply_lg_attn_mask, args.apply_lg_attn_mask,
args.apply_t5_attn_mask, args.apply_t5_attn_mask,
enable_dropout=False,
) )
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -863,6 +871,7 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
if args.use_t5xxl_cache_only: if args.use_t5xxl_cache_only:
lg_out = None lg_out = None
@@ -878,7 +887,7 @@ def train(args):
if lg_out is None: if lg_out is None:
# not cached or training, so get from text encoders # not cached or training, so get from text encoders
input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder): with torch.set_grad_enabled(train_clip):
# TODO support weighted captions # TODO support weighted captions
# text models in sd3_models require "cpu" for input_ids # text models in sd3_models require "cpu" for input_ids
input_ids_clip_l = input_ids_clip_l.to("cpu") input_ids_clip_l = input_ids_clip_l.to("cpu")
@@ -891,7 +900,7 @@ def train(args):
if t5_out is None: if t5_out is None:
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
with torch.no_grad(): with torch.set_grad_enabled(train_t5xxl):
input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
_, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]

View File

@@ -120,7 +120,13 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
return latents_caching_strategy return latents_caching_strategy
def get_text_encoding_strategy(self, args):
    # Build the SD3 encoding strategy with attention-mask flags and the
    # per-encoder dropout rates added by this commit.
    # BUG FIX: the CLI option defined in add_sd3_training_arguments is
    # --t5_dropout_rate, so the attribute is args.t5_dropout_rate
    # (train_sd3.py already uses it); args.t5xxl_dropout_rate does not
    # exist and would raise AttributeError here.
    return strategy_sd3.Sd3TextEncodingStrategy(
        args.apply_lg_attn_mask,
        args.apply_t5_attn_mask,
        args.clip_l_dropout_rate,
        args.clip_g_dropout_rate,
        args.t5_dropout_rate,
    )
def post_process_network(self, args, accelerator, network, text_encoders, unet): def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not # check t5xxl is trained or not
@@ -408,6 +414,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8 text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype) prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    # Apply per-sample dropout to cached text encoder outputs before the
    # training step consumes them.
    text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
    if text_encoder_outputs_list is not None:
        text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
        # BUG FIX: drop_cached_text_encoder_outputs takes the six cached
        # tensors as separate positional arguments, so the list must be
        # unpacked; passing the list itself raises TypeError (5 missing
        # positional arguments). Also fixes the 'text_encodoing_strategy' typo.
        text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
        batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser() parser = train_network.setup_parser()

View File

@@ -272,6 +272,9 @@ class NetworkTrainer:
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype) text_encoder.text_model.embeddings.to(dtype=weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    # Hook invoked at the start of each training step. Base implementation is
    # a no-op; subclasses may override it to mutate `batch` in place (e.g. the
    # SD3 trainer applies dropout to cached text encoder outputs here).
    pass
# endregion # endregion
def train(self, args): def train(self, args):
@@ -1030,9 +1033,9 @@ class NetworkTrainer:
# callback for step start # callback for step start
if hasattr(accelerator.unwrap_model(network), "on_step_start"): if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start on_step_start_for_network = accelerator.unwrap_model(network).on_step_start
else: else:
on_step_start = lambda *args, **kwargs: None on_step_start_for_network = lambda *args, **kwargs: None
# function for saving/removing # function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
@@ -1113,7 +1116,10 @@ class NetworkTrainer:
continue continue
with accelerator.accumulate(training_model): with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet) on_step_start_for_network(text_encoder, unet)
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -1146,6 +1152,7 @@ class NetworkTrainer:
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning # Get the text embedding for conditioning
if args.weighted_captions: if args.weighted_captions: