mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support block lr for U-Net
This commit is contained in:
@@ -5,6 +5,7 @@ import gc
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
from typing import List
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -30,6 +31,67 @@ from library.custom_train_functions import (
|
|||||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||||
|
|
||||||
|
|
||||||
|
# Number of U-Net parameter groups for per-block learning rates:
# block 0 = time/label embeddings, 1-9 = input_blocks, 10-12 = middle_block,
# 13-21 = output_blocks, 22 = out (see get_block_params_to_optimize).
UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
|
||||||
|
|
||||||
|
|
||||||
|
def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
    """Group the U-Net's parameters into per-block optimizer param groups.

    Parameters are bucketed by name prefix into ``len(block_lrs)`` blocks
    (23 for SDXL): block 0 = time/label embeddings, 1-9 = input_blocks,
    10-12 = middle_block, 13-21 = output_blocks, 22 = out. Blocks whose
    learning rate is 0 get no param group at all, so they are not optimized.

    Args:
        unet: SDXL U-Net whose named parameters are grouped.
        block_lrs: learning rate for each block, indexed as above.

    Returns:
        List of ``{"params": [...], "lr": float}`` dicts suitable for an
        optimizer constructor, one per block with a non-zero learning rate,
        in ascending block order.

    Raises:
        ValueError: if a parameter name does not match any known block prefix.
    """
    block_params = [[] for _ in range(len(block_lrs))]

    # fix: the original iterated with enumerate() but never used the index
    for name, param in unet.named_parameters():
        if name.startswith("time_embed.") or name.startswith("label_emb."):
            block_index = 0  # 0
        elif name.startswith("input_blocks."):  # 1-9
            block_index = 1 + int(name.split(".")[1])
        elif name.startswith("middle_block."):  # 10-12
            block_index = 10 + int(name.split(".")[1])
        elif name.startswith("output_blocks."):  # 13-21
            block_index = 13 + int(name.split(".")[1])
        elif name.startswith("out."):  # 22
            block_index = 22
        else:
            raise ValueError(f"unexpected parameter name: {name}")

        block_params[block_index].append(param)

    params_to_optimize = []
    for i, params in enumerate(block_params):
        if block_lrs[i] == 0:  # 0のときは学習しない do not optimize when lr is 0
            continue
        params_to_optimize.append({"params": params, "lr": block_lrs[i]})

    return params_to_optimize
|
||||||
|
|
||||||
|
|
||||||
|
def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
    """Add per-block learning rates to ``logs`` (mutated in place).

    The scheduler's last LRs are assumed to be ordered like the param groups
    built by ``get_block_params_to_optimize`` — U-Net blocks with non-zero lr
    (zero-lr blocks have no group and are skipped here too) — followed by
    text encoder 1 and text encoder 2 when those are trained.

    Args:
        block_lrs: per-block learning rates used to build the param groups.
        logs: dict of scalar values to extend (keys ``lr/<name>``).
        lr_scheduler: scheduler wrapping the optimizer(s); ``get_last_lr()``
            must return one LR per param group.
        optimizer_type: optimizer name; for DAdapt*/Prodigy the effective
            ``d * lr`` value is logged as ``lr/d*lr/<name>`` as well.

    Raises:
        ValueError: if there are more LRs than expected named groups.
    """
    lrs = lr_scheduler.get_last_lr()

    # fix: hoist the loop-invariant optimizer-type check; ".lower()" on the
    # string literals ("DAdapt".lower(), "Prodigy".lower()) was redundant
    optimizer_name = optimizer_type.lower()
    is_adaptive = optimizer_name.startswith("dadapt") or optimizer_name == "prodigy"

    lr_index = 0
    block_index = 0
    while lr_index < len(lrs):
        if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
            name = f"block{block_index}"
            if block_lrs[block_index] == 0:
                # zero-lr blocks have no param group, so no LR to consume
                block_index += 1
                continue
        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
            name = "text_encoder1"
        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
            name = "text_encoder2"
        else:
            raise ValueError(f"unexpected block_index: {block_index}")

        block_index += 1

        logs["lr/" + name] = float(lrs[lr_index])

        if is_adaptive:
            # adaptive optimizers: also log the effective rate d * lr
            logs["lr/d*lr/" + name] = (
                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
            )

        lr_index += 1
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
@@ -40,6 +102,14 @@ def train(args):
|
|||||||
not args.train_text_encoder or not args.cache_text_encoder_outputs
|
not args.train_text_encoder or not args.cache_text_encoder_outputs
|
||||||
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
|
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
|
||||||
|
|
||||||
|
if args.block_lr:
|
||||||
|
block_lrs = [float(lr) for lr in args.block_lr.split(",")]
|
||||||
|
assert (
|
||||||
|
len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
|
||||||
|
), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
|
||||||
|
else:
|
||||||
|
block_lrs = None
|
||||||
|
|
||||||
cache_latents = args.cache_latents
|
cache_latents = args.cache_latents
|
||||||
use_dreambooth_method = args.in_json is None
|
use_dreambooth_method = args.in_json is None
|
||||||
|
|
||||||
@@ -235,6 +305,8 @@ def train(args):
|
|||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.requires_grad_(True)
|
m.requires_grad_(True)
|
||||||
|
|
||||||
|
if block_lrs is None:
|
||||||
params = []
|
params = []
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
params.extend(m.parameters())
|
params.extend(m.parameters())
|
||||||
@@ -244,6 +316,17 @@ def train(args):
|
|||||||
n_params = 0
|
n_params = 0
|
||||||
for p in params:
|
for p in params:
|
||||||
n_params += p.numel()
|
n_params += p.numel()
|
||||||
|
else:
|
||||||
|
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
|
||||||
|
for m in training_models[1:]: # Text Encoders if exists
|
||||||
|
params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
|
||||||
|
|
||||||
|
# calculate number of trainable parameters
|
||||||
|
n_params = 0
|
||||||
|
for params in params_to_optimize:
|
||||||
|
for p in params["params"]:
|
||||||
|
n_params += p.numel()
|
||||||
|
|
||||||
accelerator.print(f"number of models: {len(training_models)}")
|
accelerator.print(f"number of models: {len(training_models)}")
|
||||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||||
|
|
||||||
@@ -528,13 +611,18 @@ def train(args):
|
|||||||
|
|
||||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
logs = {"loss": current_loss}
|
||||||
|
if block_lrs is None:
|
||||||
|
logs["lr"] = float(lr_scheduler.get_last_lr()[0])
|
||||||
if (
|
if (
|
||||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy"
|
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||||
): # tracking d*lr value
|
): # tracking d*lr value
|
||||||
logs["lr/d*lr"] = (
|
logs["lr/d*lr"] = (
|
||||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
|
||||||
|
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
# TODO moving averageにする
|
# TODO moving averageにする
|
||||||
@@ -638,6 +726,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--block_lr",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
|
||||||
|
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user