add LBW support for SDXL merge LoRA

This commit is contained in:
Kohya S
2024-09-13 21:29:31 +09:00
parent f4a0bea6dc
commit b755ebd0a4
2 changed files with 76 additions and 11 deletions

View File

@@ -139,9 +139,17 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
### Sep 13, 2024 / 2024-09-13: ### Sep 13, 2024 / 2024-09-13:
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release. - `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
- `sdxl_merge_lora.py` also supports LBW.
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
- These will be included in the next release.
- `sdxl_merge_lora.py` が OFT をサポートました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。 - `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
- `sdxl_merge_lora.py` でも LBW がサポートされました。
- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
- 以上は次回リリースに含まれます。
### Jun 23, 2024 / 2024-06-23: ### Jun 23, 2024 / 2024-06-23:

View File

@@ -1,7 +1,9 @@
import itertools
import math import math
import argparse import argparse
import os import os
import time import time
import concurrent.futures
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from tqdm import tqdm from tqdm import tqdm
@@ -9,13 +11,13 @@ from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util import library.model_util as model_util
import lora import lora
import oft import oft
from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import concurrent.futures
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
@@ -47,6 +49,7 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
def detect_method_from_training_model(models, dtype): def detect_method_from_training_model(models, dtype):
for model in models: for model in models:
# TODO It is better to use key names to detect the method
lora_sd, _ = load_state_dict(model, dtype) lora_sd, _ = load_state_dict(model, dtype)
for key in tqdm(lora_sd.keys()): for key in tqdm(lora_sd.keys()):
if "lora_up" in key or "lora_down" in key: if "lora_up" in key or "lora_down" in key:
@@ -55,15 +58,20 @@ def detect_method_from_training_model(models, dtype):
return "OFT" return "OFT"
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype):
text_encoder1.to(merge_dtype)
text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype)
text_encoder2.to(merge_dtype)
unet.to(merge_dtype) unet.to(merge_dtype)
# detect the method: OFT or LoRA_module # detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype) method = detect_method_from_training_model(models, merge_dtype)
logger.info(f"method:{method}") logger.info(f"method:{method}")
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
# create module map # create module map
name_to_module = {} name_to_module = {}
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
@@ -94,12 +102,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
lora_name = lora_name.replace(".", "_") lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios): for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype) lora_sd, _ = load_state_dict(model, merge_dtype)
logger.info(f"merging...") logger.info(f"merging...")
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if method == "LoRA": if method == "LoRA":
for key in tqdm(lora_sd.keys()): for key in tqdm(lora_sd.keys()):
if "lora_down" in key: if "lora_down" in key:
@@ -121,6 +135,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
alpha = lora_sd.get(alpha_key, dim) alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim scale = alpha / dim
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
# W <- W + U * D # W <- W + U * D
weight = module.weight weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size()) # logger.info(module_name, down_weight.size(), up_weight.size())
@@ -145,7 +165,6 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
elif method == "OFT": elif method == "OFT":
multiplier = 1.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for key in tqdm(lora_sd.keys()): for key in tqdm(lora_sd.keys()):
@@ -183,6 +202,13 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
block_size = out_dim // dim block_size = out_dim // dim
constraint = (0 if alpha is None else alpha) * out_dim constraint = (0 if alpha is None else alpha) * out_dim
multiplier = 1
if lbw:
index = get_lbw_block_index(key, False)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
multiplier *= lbw_weights[index]
block_Q = oft_blocks - oft_blocks.transpose(1, 2) block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten()) norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=constraint) new_norm_Q = torch.clamp(norm_Q, max=constraint)
@@ -213,17 +239,35 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model base_alphas = {} # alpha for merged model
base_dims = {} base_dims = {}
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
if method == "OFT":
raise ValueError(
"OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません"
)
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
merged_sd = {} merged_sd = {}
v2 = None v2 = None
base_model = None base_model = None
for model, ratio in zip(models, ratios): for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype) lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if lora_metadata is not None: if lora_metadata is not None:
if v2 is None: if v2 is None:
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
@@ -277,6 +321,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
scale = math.sqrt(alpha / base_alpha) * ratio scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
if key in merged_sd: if key in merged_sd:
assert ( assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
@@ -329,6 +379,12 @@ def merge(args):
assert len(args.models) == len( assert len(args.models) == len(
args.ratios args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
def str_to_dtype(p): def str_to_dtype(p):
if p == "float": if p == "float":
@@ -356,7 +412,7 @@ def merge(args):
ckpt_info, ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype)
if args.no_metadata: if args.no_metadata:
sai_metadata = None sai_metadata = None
@@ -372,7 +428,7 @@ def merge(args):
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
) )
else: else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
logger.info(f"calculating hashes and creating metadata...") logger.info(f"calculating hashes and creating metadata...")
@@ -427,6 +483,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
) )
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument( parser.add_argument(
"--no_metadata", "--no_metadata",
action="store_true", action="store_true",