support block lr for U-Net
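Add a --block_lr option to sdxl_train.py: 23 comma-separated learning rates, one per U-Net block group (time/label embeddings, 9 input blocks, 3 middle blocks, 9 output blocks, and the final out layer). A value of 0 excludes that block from the optimizer. When the option is set, per-block learning rates are logged individually, including d*lr for DAdaptation and Prodigy optimizers.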

Author: Kohya S
Date: 2023-08-12 13:13:59 +09:00
Parent: 8415014de6
Commit: e2c2689f5c


@@ -5,6 +5,7 @@ import gc
 import math
 import os
 from multiprocessing import Value
+from typing import List

 import toml
 from tqdm import tqdm
@@ -30,6 +31,67 @@ from library.custom_train_functions import (
 from library.sdxl_original_unet import SdxlUNet2DConditionModel

+UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
+
+
+def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
+    block_params = [[] for _ in range(len(block_lrs))]
+
+    for i, (name, param) in enumerate(unet.named_parameters()):
+        if name.startswith("time_embed.") or name.startswith("label_emb."):
+            block_index = 0  # 0
+        elif name.startswith("input_blocks."):  # 1-9
+            block_index = 1 + int(name.split(".")[1])
+        elif name.startswith("middle_block."):  # 10-12
+            block_index = 10 + int(name.split(".")[1])
+        elif name.startswith("output_blocks."):  # 13-21
+            block_index = 13 + int(name.split(".")[1])
+        elif name.startswith("out."):  # 22
+            block_index = 22
+        else:
+            raise ValueError(f"unexpected parameter name: {name}")
+
+        block_params[block_index].append(param)
+
+    params_to_optimize = []
+    for i, params in enumerate(block_params):
+        if block_lrs[i] == 0:  # do not optimize when lr is 0
+            continue
+        params_to_optimize.append({"params": params, "lr": block_lrs[i]})
+
+    return params_to_optimize
+
+
+def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
+    lrs = lr_scheduler.get_last_lr()
+
+    lr_index = 0
+    block_index = 0
+    while lr_index < len(lrs):
+        if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
+            name = f"block{block_index}"
+            if block_lrs[block_index] == 0:
+                block_index += 1
+                continue
+        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
+            name = "text_encoder1"
+        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
+            name = "text_encoder2"
+        else:
+            raise ValueError(f"unexpected block_index: {block_index}")
+
+        block_index += 1
+
+        logs["lr/" + name] = float(lrs[lr_index])
+
+        if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
+            logs["lr/d*lr/" + name] = (
+                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
+            )
+
+        lr_index += 1
+
+
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
@@ -40,6 +102,14 @@ def train(args):
         not args.train_text_encoder or not args.cache_text_encoder_outputs
     ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"

+    if args.block_lr:
+        block_lrs = [float(lr) for lr in args.block_lr.split(",")]
+        assert (
+            len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
+        ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
+    else:
+        block_lrs = None
+
     cache_latents = args.cache_latents
     use_dreambooth_method = args.in_json is None
@@ -235,15 +305,28 @@ def train(args):
     for m in training_models:
         m.requires_grad_(True)

-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
-
-    # calculate number of trainable parameters
-    n_params = 0
-    for p in params:
-        n_params += p.numel()
+    if block_lrs is None:
+        params = []
+        for m in training_models:
+            params.extend(m.parameters())
+        params_to_optimize = params
+
+        # calculate number of trainable parameters
+        n_params = 0
+        for p in params:
+            n_params += p.numel()
+    else:
+        params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
+        for m in training_models[1:]:  # Text Encoders, if any
+            params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
+
+        # calculate number of trainable parameters
+        n_params = 0
+        for params in params_to_optimize:
+            for p in params["params"]:
+                n_params += p.numel()

     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
@@ -528,13 +611,18 @@ def train(args):
             current_loss = loss.detach().item()  # mean over the batch, so batch size should not matter
             if args.logging_dir is not None:
-                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if (
-                    args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy"
-                ):  # tracking d*lr value
-                    logs["lr/d*lr"] = (
-                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
-                    )
+                logs = {"loss": current_loss}
+                if block_lrs is None:
+                    logs["lr"] = float(lr_scheduler.get_last_lr()[0])
+                    if (
+                        args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
+                    ):  # tracking d*lr value
+                        logs["lr/d*lr"] = (
+                            lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
+                        )
+                else:
+                    append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
+
                 accelerator.log(logs, step=global_step)

             # TODO use a moving average
@@ -638,6 +726,13 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--block_lr",
+        type=str,
+        default=None,
+        help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+        + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
+    )

     return parser
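
For reference, a minimal sketch of how the new helper behaves outside the training loop. This is not part of the commit; it assumes the repository's dependencies are installed and that it runs from the repo root so that sdxl_train and library.sdxl_original_unet are importable, and the learning-rate values are purely illustrative.

# sketch: build per-block parameter groups and hand them to a stock optimizer
import torch

from library.sdxl_original_unet import SdxlUNet2DConditionModel
from sdxl_train import UNET_NUM_BLOCKS_FOR_BLOCK_LR, get_block_params_to_optimize

unet = SdxlUNet2DConditionModel()  # randomly initialized U-Net (constructor defaults assumed)

# one value per block: 0 = time/label embeddings, 1-9 = input blocks,
# 10-12 = middle blocks, 13-21 = output blocks, 22 = final out layer
block_lrs = [1e-5] * UNET_NUM_BLOCKS_FOR_BLOCK_LR
block_lrs[0] = 0.0    # an lr of 0 removes the block from optimization entirely
block_lrs[10] = 2e-5  # e.g. train the first middle block faster (illustrative)

param_groups = get_block_params_to_optimize(unet, block_lrs)
print(len(param_groups))  # 22 -- block 0 is skipped because its lr is 0

# the result is standard PyTorch param-group syntax, so the per-group "lr"
# overrides the optimizer default for that block's parameters
optimizer = torch.optim.AdamW(param_groups, lr=1e-5)

On the command line the same mapping is driven by the new flag, e.g. --block_lr "0,1e-5,1e-5,...,1e-5" with exactly 23 comma-separated values; training then logs lr/block{i} for each trained block (plus lr/text_encoder1 and lr/text_encoder2 when the text encoders are trained) instead of a single lr entry.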