diff --git a/sdxl_train.py b/sdxl_train.py
index f067acd5..47bc6a42 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -10,10 +10,13 @@
 import toml
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -272,10 +275,11 @@ def train(args):
     accelerator.wait_for_everyone()
 
     # prepare for training: put the models into the appropriate state
-    training_models = []
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-    training_models.append(unet)
+    train_unet = args.learning_rate > 0
+    train_text_encoder1 = False
+    train_text_encoder2 = False
 
     if args.train_text_encoder:
         # TODO each option for two text encoders?
@@ -283,9 +287,20 @@ def train(args):
         if args.gradient_checkpointing:
             text_encoder1.gradient_checkpointing_enable()
             text_encoder2.gradient_checkpointing_enable()
-        training_models.append(text_encoder1)
-        training_models.append(text_encoder2)
-        # set require_grad=True later
+        lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means do not train
+        lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means do not train
+        train_text_encoder1 = lr_te1 > 0
+        train_text_encoder2 = lr_te2 > 0
+
+        # caching one text encoder output is not supported
+        if not train_text_encoder1:
+            text_encoder1.to(weight_dtype)
+        if not train_text_encoder2:
+            text_encoder2.to(weight_dtype)
+        text_encoder1.requires_grad_(train_text_encoder1)
+        text_encoder2.requires_grad_(train_text_encoder2)
+        text_encoder1.train(train_text_encoder1)
+        text_encoder2.train(train_text_encoder2)
     else:
         text_encoder1.to(weight_dtype)
         text_encoder2.to(weight_dtype)
@@ -313,21 +328,25 @@ def train(args):
     vae.eval()
     vae.to(accelerator.device, dtype=vae_dtype)
 
-    for m in training_models:
-        m.requires_grad_(True)
+    unet.requires_grad_(train_unet)
+    if not train_unet:
+        unet.to(accelerator.device, dtype=weight_dtype)  # because unet is not prepared
 
-    if block_lrs is None:
-        params_to_optimize = [
-            {"params": list(training_models[0].parameters()), "lr": args.learning_rate},
-        ]
-    else:
-        params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
+    training_models = []
+    params_to_optimize = []
+    if train_unet:
+        training_models.append(unet)
+        if block_lrs is None:
+            params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
+        else:
+            params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
 
-    for m in training_models[1:]:  # Text Encoders if exists
-        params_to_optimize.append({
-            "params": list(m.parameters()),
-            "lr": args.learning_rate_te or args.learning_rate
-        })
+    if train_text_encoder1:
+        training_models.append(text_encoder1)
+        params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
+    if train_text_encoder2:
+        training_models.append(text_encoder2)
+        params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
 
     # calculate number of trainable parameters
     n_params = 0
@@ -335,6 +354,7 @@ def train(args):
         for p in params["params"]:
             n_params += p.numel()
 
+    accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
 
@@ -386,16 +406,17 @@ def train(args):
         text_encoder2.to(weight_dtype)
 
     # accelerator apparently takes care of this for us
-    if args.train_text_encoder:
-        unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
-        )
-
-        # transform DDP after prepare
-        text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
-
-    else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    if train_unet:
+        unet = accelerator.prepare(unet)
+        (unet,) = train_util.transform_models_if_DDP([unet])
+    if train_text_encoder1:
+        text_encoder1 = accelerator.prepare(text_encoder1)
+        (text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
+    if train_text_encoder2:
+        text_encoder2 = accelerator.prepare(text_encoder2)
+        (text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
+
+    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     # move to the CPU when caching the TextEncoder outputs
     if args.cache_text_encoder_outputs:
@@ -461,7 +482,7 @@ def train(args):
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
 
-            with accelerator.accumulate(training_models[0]):  # does not seem to support multiple models, but this will do for now
+            with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                 else:
@@ -547,7 +568,12 @@ def train(args):
                     target = noise
 
-                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss:
+                if (
+                    args.min_snr_gamma
+                    or args.scale_v_pred_loss_like_noise_pred
+                    or args.v_pred_like_loss
+                    or args.debiased_estimation_loss
+                ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
 
@@ -725,7 +751,19 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
     sdxl_train_util.add_sdxl_training_arguments(parser)
-    parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
+
+    parser.add_argument(
+        "--learning_rate_te1",
+        type=float,
+        default=None,
+        help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
+    )
+    parser.add_argument(
+        "--learning_rate_te2",
+        type=float,
+        default=None,
+        help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
+    )
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
 
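Taken together, the patch replaces the single `--learning_rate_te` option with per-encoder rates: each model is trained only when its effective learning rate resolves to a positive value, with `None` falling back to the base `--learning_rate` and an explicit `0` freezing that model. Below is a minimal standalone sketch of that resolution logic; the helper name is hypothetical, and the flow is simplified from the patch, where the text-encoder flags are additionally gated on `--train_text_encoder`:

```python
# Hypothetical helper mirroring the train_unet / train_text_encoder1/2
# gating introduced by the patch (illustrative, not part of sdxl_train.py).
def resolve_train_flags(learning_rate, learning_rate_te1=None, learning_rate_te2=None, train_text_encoder=True):
    # per-encoder rates fall back to the base rate when left at None;
    # an explicit 0 disables training for that encoder
    lr_te1 = learning_rate_te1 if learning_rate_te1 is not None else learning_rate
    lr_te2 = learning_rate_te2 if learning_rate_te2 is not None else learning_rate
    return {
        "train_unet": learning_rate > 0,
        "train_text_encoder1": train_text_encoder and lr_te1 > 0,
        "train_text_encoder2": train_text_encoder and lr_te2 > 0,
    }

# e.g. train the U-Net and text encoder 1, but freeze text encoder 2:
print(resolve_train_flags(1e-5, learning_rate_te2=0.0))
# {'train_unet': True, 'train_text_encoder1': True, 'train_text_encoder2': False}
```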
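The prepare/accumulate flow also changes shape: each trainable model now goes through `accelerator.prepare` (and the DDP transform) individually, and gradient accumulation wraps every trainable model via `accumulate(*training_models)` rather than only `training_models[0]`. Here is a minimal sketch of that pattern with stand-in modules; the tiny `nn.Linear` models and hyperparameters are placeholders, and passing multiple models to `accumulate` assumes a recent version of accelerate:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

# stand-ins for unet / text_encoder1; only trainable models are prepared,
# mirroring the per-model accelerator.prepare calls in the patch
candidates = [(torch.nn.Linear(8, 8), True), (torch.nn.Linear(8, 8), False)]
training_models = [accelerator.prepare(m) for m, trainable in candidates if trainable]

optimizer = accelerator.prepare(
    torch.optim.AdamW([{"params": m.parameters(), "lr": 1e-5} for m in training_models])
)

batch = torch.randn(2, 8, device=accelerator.device)
with accelerator.accumulate(*training_models):  # accumulate over all trainable models
    loss = sum(m(batch).pow(2).mean() for m in training_models)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```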