FLUX.1 LoRA supports CLIP-L

This commit is contained in:
Kohya S
2024-08-27 19:59:40 +09:00
parent 72287d39c7
commit 0087a46e14
6 changed files with 101 additions and 43 deletions

View File

@@ -127,8 +127,15 @@ class NetworkTrainer:
return None
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
"""
Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
"""
return text_encoders
# returns a list of bool values indicating whether each text encoder should be trained
def get_text_encoders_train_flags(self, args, text_encoders):
return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders)
def is_train_text_encoder(self, args):
return not args.network_train_unet_only
@@ -136,11 +143,6 @@ class NetworkTrainer:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype)
return encoder_hidden_states
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
return noise_pred
@@ -313,7 +315,7 @@ class NetworkTrainer:
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.debug_dataset:
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
@@ -437,8 +439,10 @@ class NetworkTrainer:
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
for t_enc in text_encoders:
t_enc.gradient_checkpointing_enable()
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
if flag:
if t_enc.supports_gradient_checkpointing:
t_enc.gradient_checkpointing_enable()
del t_enc
network.enable_gradient_checkpointing() # may have no effect
@@ -522,14 +526,17 @@ class NetworkTrainer:
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
if args.fp8_base or args.fp8_base_unet:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
accelerator.print("enable fp8 training for U-Net.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
if not args.fp8_base_unet:
accelerator.print("enable fp8 training for Text Encoder.")
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
@@ -546,19 +553,18 @@ class NetworkTrainer:
t_enc.to(dtype=te_weight_dtype)
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
t_enc.encoder.embeddings.to(
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
flags = self.get_text_encoders_train_flags(args, text_encoders)
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoders[0] if train_text_encoder else None,
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
text_encoder1=text_encoders[0] if flags[0] else None,
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
network=network,
)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -571,11 +577,14 @@ class NetworkTrainer:
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
text_encoders = [
(accelerator.prepare(t_enc) if flag else t_enc)
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
]
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
text_encoder = text_encoders
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
text_encoder = text_encoders[0]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
@@ -587,11 +596,11 @@ class NetworkTrainer:
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
for t_enc in text_encoders:
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works
if train_text_encoder:
if frag:
t_enc.text_model.embeddings.requires_grad_(True)
else:
@@ -736,6 +745,7 @@ class NetworkTrainer:
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
"ss_fp8_base": args.fp8_base,
"ss_fp8_base_unet": args.fp8_base_unet,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1004,6 +1014,7 @@ class NetworkTrainer:
for t_enc in text_encoders:
del t_enc
text_encoders = []
text_encoder = None
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
@@ -1018,7 +1029,7 @@ class NetworkTrainer:
# log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
for t_enc in text_encoders:
logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}")
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
clean_memory_on_device(accelerator.device)
@@ -1073,12 +1084,17 @@ class NetworkTrainer:
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
else:
if (
text_encoder_conds is None
or len(text_encoder_conds) == 0
or text_encoder_conds[0] is None
or train_text_encoder
):
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
# SD only
text_encoder_conds = get_weighted_text_embeddings(
encoded_text_encoder_conds = get_weighted_text_embeddings(
tokenizers[0],
text_encoder,
batch["captions"],
@@ -1088,13 +1104,18 @@ class NetworkTrainer:
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
@@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument(
"--fp8_base_unet",
action="store_true",
help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
" / U-NetまたはDiTにfp8を使用する。Text Encoderはfp16またはbf16",
)
parser.add_argument(
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"