diff --git a/sdxl_train.py b/sdxl_train.py
index e62bc377..195467b0 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -5,6 +5,7 @@ import gc
 import math
 import os
 from multiprocessing import Value
+from typing import List
 
 import toml
 from tqdm import tqdm
@@ -30,6 +31,67 @@ from library.custom_train_functions import (
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
 
 
+UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
+
+
+def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
+    block_params = [[] for _ in range(len(block_lrs))]
+
+    for i, (name, param) in enumerate(unet.named_parameters()):
+        if name.startswith("time_embed.") or name.startswith("label_emb."):
+            block_index = 0  # 0
+        elif name.startswith("input_blocks."):  # 1-9
+            block_index = 1 + int(name.split(".")[1])
+        elif name.startswith("middle_block."):  # 10-12
+            block_index = 10 + int(name.split(".")[1])
+        elif name.startswith("output_blocks."):  # 13-21
+            block_index = 13 + int(name.split(".")[1])
+        elif name.startswith("out."):  # 22
+            block_index = 22
+        else:
+            raise ValueError(f"unexpected parameter name: {name}")
+
+        block_params[block_index].append(param)
+
+    params_to_optimize = []
+    for i, params in enumerate(block_params):
+        if block_lrs[i] == 0:  # 0のときは学習しない do not optimize when lr is 0
+            continue
+        params_to_optimize.append({"params": params, "lr": block_lrs[i]})
+
+    return params_to_optimize
+
+
+def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
+    lrs = lr_scheduler.get_last_lr()
+
+    lr_index = 0
+    block_index = 0
+    while lr_index < len(lrs):
+        if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
+            name = f"block{block_index}"
+            if block_lrs[block_index] == 0:
+                block_index += 1
+                continue
+        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
+            name = "text_encoder1"
+        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
+            name = "text_encoder2"
+        else:
+            raise ValueError(f"unexpected block_index: {block_index}")
+
+        block_index += 1
+
+        logs["lr/" + name] = float(lrs[lr_index])
+
+        if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
+            logs["lr/d*lr/" + name] = (
+                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
+            )
+
+        lr_index += 1
+
+
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
@@ -40,6 +102,14 @@ def train(args):
         not args.train_text_encoder or not args.cache_text_encoder_outputs
     ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
 
+    if args.block_lr:
+        block_lrs = [float(lr) for lr in args.block_lr.split(",")]
+        assert (
+            len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
+        ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
+    else:
+        block_lrs = None
+
     cache_latents = args.cache_latents
     use_dreambooth_method = args.in_json is None
 
@@ -235,15 +305,28 @@
 
     for m in training_models:
         m.requires_grad_(True)
-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
-    # calculate number of trainable parameters
-    n_params = 0
-    for p in params:
-        n_params += p.numel()
+    if block_lrs is None:
+        params = []
+        for m in training_models:
+            params.extend(m.parameters())
+        params_to_optimize = params
+
+        # calculate number of trainable parameters
+        n_params = 0
+        for p in params:
+            n_params += p.numel()
+    else:
+        params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
+        for m in training_models[1:]:  # Text Encoders if exists
+            params_to_optimize.append({"params": list(m.parameters()), "lr": args.learning_rate})  # list() so the count below doesn't exhaust the generator
+
+        # calculate number of trainable parameters
+        n_params = 0
+        for params in params_to_optimize:
+            for p in params["params"]:
+                n_params += p.numel()
+
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
 
@@ -528,13 +611,18 @@
 
             current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
             if args.logging_dir is not None:
-                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if (
-                    args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy"
-                ):  # tracking d*lr value
-                    logs["lr/d*lr"] = (
-                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
-                    )
+                logs = {"loss": current_loss}
+                if block_lrs is None:
+                    logs["lr"] = float(lr_scheduler.get_last_lr()[0])
+                    if (
+                        args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
+                    ):  # tracking d*lr value
+                        logs["lr/d*lr"] = (
+                            lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
+                        )
+                else:
+                    append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
+
                 accelerator.log(logs, step=global_step)
 
             # TODO moving averageにする
@@ -638,6 +726,13 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--block_lr",
+        type=str,
+        default=None,
+        help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+        + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
+    )
 
     return parser
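
Usage sketch (illustrative, not part of the patch): --block_lr takes 23 comma-separated values that follow the index mapping hard-coded in get_block_params_to_optimize, and a value of 0 freezes that block (it is skipped when the optimizer parameter groups are built). The learning-rate values below are arbitrary placeholders.

    # Build a 23-value --block_lr string for sdxl_train.py.
    # Index layout from the patch: 0 = time_embed./label_emb., 1-9 = input_blocks,
    # 10-12 = middle_block, 13-21 = output_blocks, 22 = out.
    block_lrs = [0.0] + [1e-6] * 9 + [2e-6] * 3 + [4e-6] * 9 + [0.0]
    assert len(block_lrs) == 23
    print(",".join(str(lr) for lr in block_lrs))  # pass the printed string via --block_lr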