mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into sd3
This commit is contained in:
12
README.md
12
README.md
@@ -707,9 +707,17 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します
|
|||||||
|
|
||||||
### Sep 13, 2024 / 2024-09-13:
|
### Sep 13, 2024 / 2024-09-13:
|
||||||
|
|
||||||
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release.
|
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
|
||||||
|
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
|
||||||
|
- `sdxl_merge_lora.py` also supports LBW.
|
||||||
|
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
|
||||||
|
- These will be included in the next release.
|
||||||
|
|
||||||
- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。
|
- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
|
||||||
|
- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
|
||||||
|
- `sdxl_merge_lora.py` でも LBW がサポートされました。
|
||||||
|
- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
|
||||||
|
- 以上は次回リリースに含まれます。
|
||||||
|
|
||||||
### Jun 23, 2024 / 2024-06-23:
|
### Jun 23, 2024 / 2024-06-23:
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
import itertools
|
||||||
import math
|
import math
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import concurrent.futures
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -9,13 +11,13 @@ from library import sai_model_spec, sdxl_model_util, train_util
|
|||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import lora
|
import lora
|
||||||
import oft
|
import oft
|
||||||
|
from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
import concurrent.futures
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
@@ -47,6 +49,7 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
|||||||
|
|
||||||
def detect_method_from_training_model(models, dtype):
|
def detect_method_from_training_model(models, dtype):
|
||||||
for model in models:
|
for model in models:
|
||||||
|
# TODO It is better to use key names to detect the method
|
||||||
lora_sd, _ = load_state_dict(model, dtype)
|
lora_sd, _ = load_state_dict(model, dtype)
|
||||||
for key in tqdm(lora_sd.keys()):
|
for key in tqdm(lora_sd.keys()):
|
||||||
if "lora_up" in key or "lora_down" in key:
|
if "lora_up" in key or "lora_down" in key:
|
||||||
@@ -55,15 +58,20 @@ def detect_method_from_training_model(models, dtype):
|
|||||||
return "OFT"
|
return "OFT"
|
||||||
|
|
||||||
|
|
||||||
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
|
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype):
|
||||||
text_encoder1.to(merge_dtype)
|
|
||||||
text_encoder1.to(merge_dtype)
|
text_encoder1.to(merge_dtype)
|
||||||
|
text_encoder2.to(merge_dtype)
|
||||||
unet.to(merge_dtype)
|
unet.to(merge_dtype)
|
||||||
|
|
||||||
# detect the method: OFT or LoRA_module
|
# detect the method: OFT or LoRA_module
|
||||||
method = detect_method_from_training_model(models, merge_dtype)
|
method = detect_method_from_training_model(models, merge_dtype)
|
||||||
logger.info(f"method:{method}")
|
logger.info(f"method:{method}")
|
||||||
|
|
||||||
|
if lbws:
|
||||||
|
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
|
||||||
|
else:
|
||||||
|
LBW_TARGET_IDX = []
|
||||||
|
|
||||||
# create module map
|
# create module map
|
||||||
name_to_module = {}
|
name_to_module = {}
|
||||||
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
|
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
|
||||||
@@ -94,12 +102,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
lora_name = lora_name.replace(".", "_")
|
lora_name = lora_name.replace(".", "_")
|
||||||
name_to_module[lora_name] = child_module
|
name_to_module[lora_name] = child_module
|
||||||
|
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
|
||||||
logger.info(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
logger.info(f"merging...")
|
logger.info(f"merging...")
|
||||||
|
|
||||||
|
if lbw:
|
||||||
|
lbw_weights = [1] * 26
|
||||||
|
for index, value in zip(LBW_TARGET_IDX, lbw):
|
||||||
|
lbw_weights[index] = value
|
||||||
|
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
|
||||||
|
|
||||||
if method == "LoRA":
|
if method == "LoRA":
|
||||||
for key in tqdm(lora_sd.keys()):
|
for key in tqdm(lora_sd.keys()):
|
||||||
if "lora_down" in key:
|
if "lora_down" in key:
|
||||||
@@ -121,6 +135,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
alpha = lora_sd.get(alpha_key, dim)
|
alpha = lora_sd.get(alpha_key, dim)
|
||||||
scale = alpha / dim
|
scale = alpha / dim
|
||||||
|
|
||||||
|
if lbw:
|
||||||
|
index = get_lbw_block_index(key, True)
|
||||||
|
is_lbw_target = index in LBW_TARGET_IDX
|
||||||
|
if is_lbw_target:
|
||||||
|
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
|
||||||
|
|
||||||
# W <- W + U * D
|
# W <- W + U * D
|
||||||
weight = module.weight
|
weight = module.weight
|
||||||
# logger.info(module_name, down_weight.size(), up_weight.size())
|
# logger.info(module_name, down_weight.size(), up_weight.size())
|
||||||
@@ -145,7 +165,6 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
|
|
||||||
elif method == "OFT":
|
elif method == "OFT":
|
||||||
|
|
||||||
multiplier = 1.0
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
for key in tqdm(lora_sd.keys()):
|
for key in tqdm(lora_sd.keys()):
|
||||||
@@ -183,6 +202,13 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
block_size = out_dim // dim
|
block_size = out_dim // dim
|
||||||
constraint = (0 if alpha is None else alpha) * out_dim
|
constraint = (0 if alpha is None else alpha) * out_dim
|
||||||
|
|
||||||
|
multiplier = 1
|
||||||
|
if lbw:
|
||||||
|
index = get_lbw_block_index(key, False)
|
||||||
|
is_lbw_target = index in LBW_TARGET_IDX
|
||||||
|
if is_lbw_target:
|
||||||
|
multiplier *= lbw_weights[index]
|
||||||
|
|
||||||
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
|
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
|
||||||
norm_Q = torch.norm(block_Q.flatten())
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
new_norm_Q = torch.clamp(norm_Q, max=constraint)
|
new_norm_Q = torch.clamp(norm_Q, max=constraint)
|
||||||
@@ -213,17 +239,35 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
|
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
|
||||||
|
|
||||||
|
|
||||||
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False):
|
||||||
base_alphas = {} # alpha for merged model
|
base_alphas = {} # alpha for merged model
|
||||||
base_dims = {}
|
base_dims = {}
|
||||||
|
|
||||||
|
# detect the method: OFT or LoRA_module
|
||||||
|
method = detect_method_from_training_model(models, merge_dtype)
|
||||||
|
if method == "OFT":
|
||||||
|
raise ValueError(
|
||||||
|
"OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません"
|
||||||
|
)
|
||||||
|
|
||||||
|
if lbws:
|
||||||
|
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
|
||||||
|
else:
|
||||||
|
LBW_TARGET_IDX = []
|
||||||
|
|
||||||
merged_sd = {}
|
merged_sd = {}
|
||||||
v2 = None
|
v2 = None
|
||||||
base_model = None
|
base_model = None
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
|
||||||
logger.info(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
|
if lbw:
|
||||||
|
lbw_weights = [1] * 26
|
||||||
|
for index, value in zip(LBW_TARGET_IDX, lbw):
|
||||||
|
lbw_weights[index] = value
|
||||||
|
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
|
||||||
|
|
||||||
if lora_metadata is not None:
|
if lora_metadata is not None:
|
||||||
if v2 is None:
|
if v2 is None:
|
||||||
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
|
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
|
||||||
@@ -277,6 +321,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
|||||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||||
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
|
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
|
||||||
|
|
||||||
|
if lbw:
|
||||||
|
index = get_lbw_block_index(key, True)
|
||||||
|
is_lbw_target = index in LBW_TARGET_IDX
|
||||||
|
if is_lbw_target:
|
||||||
|
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
|
||||||
|
|
||||||
if key in merged_sd:
|
if key in merged_sd:
|
||||||
assert (
|
assert (
|
||||||
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
|
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
|
||||||
@@ -329,6 +379,12 @@ def merge(args):
|
|||||||
assert len(args.models) == len(
|
assert len(args.models) == len(
|
||||||
args.ratios
|
args.ratios
|
||||||
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||||
|
if args.lbws:
|
||||||
|
assert len(args.models) == len(
|
||||||
|
args.lbws
|
||||||
|
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
|
||||||
|
else:
|
||||||
|
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
|
||||||
|
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == "float":
|
if p == "float":
|
||||||
@@ -356,7 +412,7 @@ def merge(args):
|
|||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
|
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
|
||||||
|
|
||||||
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
|
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype)
|
||||||
|
|
||||||
if args.no_metadata:
|
if args.no_metadata:
|
||||||
sai_metadata = None
|
sai_metadata = None
|
||||||
@@ -372,7 +428,7 @@ def merge(args):
|
|||||||
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
|
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
|
||||||
|
|
||||||
logger.info(f"calculating hashes and creating metadata...")
|
logger.info(f"calculating hashes and creating metadata...")
|
||||||
|
|
||||||
@@ -427,6 +483,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
|
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||||
|
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_metadata",
|
"--no_metadata",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import itertools
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
@@ -8,12 +11,194 @@ from library import sai_model_spec, train_util
|
|||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import lora
|
import lora
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
|
ACCEPTABLE = [12, 17, 20, 26]
|
||||||
|
SDXL_LAYER_NUM = [12, 20]
|
||||||
|
|
||||||
|
LAYER12 = {
|
||||||
|
"BASE": True,
|
||||||
|
"IN00": False,
|
||||||
|
"IN01": False,
|
||||||
|
"IN02": False,
|
||||||
|
"IN03": False,
|
||||||
|
"IN04": True,
|
||||||
|
"IN05": True,
|
||||||
|
"IN06": False,
|
||||||
|
"IN07": True,
|
||||||
|
"IN08": True,
|
||||||
|
"IN09": False,
|
||||||
|
"IN10": False,
|
||||||
|
"IN11": False,
|
||||||
|
"MID": True,
|
||||||
|
"OUT00": True,
|
||||||
|
"OUT01": True,
|
||||||
|
"OUT02": True,
|
||||||
|
"OUT03": True,
|
||||||
|
"OUT04": True,
|
||||||
|
"OUT05": True,
|
||||||
|
"OUT06": False,
|
||||||
|
"OUT07": False,
|
||||||
|
"OUT08": False,
|
||||||
|
"OUT09": False,
|
||||||
|
"OUT10": False,
|
||||||
|
"OUT11": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
LAYER17 = {
|
||||||
|
"BASE": True,
|
||||||
|
"IN00": False,
|
||||||
|
"IN01": True,
|
||||||
|
"IN02": True,
|
||||||
|
"IN03": False,
|
||||||
|
"IN04": True,
|
||||||
|
"IN05": True,
|
||||||
|
"IN06": False,
|
||||||
|
"IN07": True,
|
||||||
|
"IN08": True,
|
||||||
|
"IN09": False,
|
||||||
|
"IN10": False,
|
||||||
|
"IN11": False,
|
||||||
|
"MID": True,
|
||||||
|
"OUT00": False,
|
||||||
|
"OUT01": False,
|
||||||
|
"OUT02": False,
|
||||||
|
"OUT03": True,
|
||||||
|
"OUT04": True,
|
||||||
|
"OUT05": True,
|
||||||
|
"OUT06": True,
|
||||||
|
"OUT07": True,
|
||||||
|
"OUT08": True,
|
||||||
|
"OUT09": True,
|
||||||
|
"OUT10": True,
|
||||||
|
"OUT11": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
LAYER20 = {
|
||||||
|
"BASE": True,
|
||||||
|
"IN00": True,
|
||||||
|
"IN01": True,
|
||||||
|
"IN02": True,
|
||||||
|
"IN03": True,
|
||||||
|
"IN04": True,
|
||||||
|
"IN05": True,
|
||||||
|
"IN06": True,
|
||||||
|
"IN07": True,
|
||||||
|
"IN08": True,
|
||||||
|
"IN09": False,
|
||||||
|
"IN10": False,
|
||||||
|
"IN11": False,
|
||||||
|
"MID": True,
|
||||||
|
"OUT00": True,
|
||||||
|
"OUT01": True,
|
||||||
|
"OUT02": True,
|
||||||
|
"OUT03": True,
|
||||||
|
"OUT04": True,
|
||||||
|
"OUT05": True,
|
||||||
|
"OUT06": True,
|
||||||
|
"OUT07": True,
|
||||||
|
"OUT08": True,
|
||||||
|
"OUT09": False,
|
||||||
|
"OUT10": False,
|
||||||
|
"OUT11": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
LAYER26 = {
|
||||||
|
"BASE": True,
|
||||||
|
"IN00": True,
|
||||||
|
"IN01": True,
|
||||||
|
"IN02": True,
|
||||||
|
"IN03": True,
|
||||||
|
"IN04": True,
|
||||||
|
"IN05": True,
|
||||||
|
"IN06": True,
|
||||||
|
"IN07": True,
|
||||||
|
"IN08": True,
|
||||||
|
"IN09": True,
|
||||||
|
"IN10": True,
|
||||||
|
"IN11": True,
|
||||||
|
"MID": True,
|
||||||
|
"OUT00": True,
|
||||||
|
"OUT01": True,
|
||||||
|
"OUT02": True,
|
||||||
|
"OUT03": True,
|
||||||
|
"OUT04": True,
|
||||||
|
"OUT05": True,
|
||||||
|
"OUT06": True,
|
||||||
|
"OUT07": True,
|
||||||
|
"OUT08": True,
|
||||||
|
"OUT09": True,
|
||||||
|
"OUT10": True,
|
||||||
|
"OUT11": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert len([v for v in LAYER12.values() if v]) == 12
|
||||||
|
assert len([v for v in LAYER17.values() if v]) == 17
|
||||||
|
assert len([v for v in LAYER20.values() if v]) == 20
|
||||||
|
assert len([v for v in LAYER26.values() if v]) == 26
|
||||||
|
|
||||||
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
|
|
||||||
|
|
||||||
|
def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
|
||||||
|
# lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder
|
||||||
|
if "text_model_encoder_" in lora_name: # LoRA for text encoder
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2
|
||||||
|
block_idx = -1 # invalid lora name
|
||||||
|
if not is_sdxl:
|
||||||
|
NUM_OF_BLOCKS = 12 # up/down blocks
|
||||||
|
m = RE_UPDOWN.search(lora_name)
|
||||||
|
if m:
|
||||||
|
g = m.groups()
|
||||||
|
up_down = g[0]
|
||||||
|
i = int(g[1])
|
||||||
|
j = int(g[3])
|
||||||
|
if up_down == "down":
|
||||||
|
if g[2] == "resnets" or g[2] == "attentions":
|
||||||
|
idx = 3 * i + j + 1
|
||||||
|
elif g[2] == "downsamplers":
|
||||||
|
idx = 3 * (i + 1)
|
||||||
|
else:
|
||||||
|
return block_idx # invalid lora name
|
||||||
|
elif up_down == "up":
|
||||||
|
if g[2] == "resnets" or g[2] == "attentions":
|
||||||
|
idx = 3 * i + j
|
||||||
|
elif g[2] == "upsamplers":
|
||||||
|
idx = 3 * i + 2
|
||||||
|
else:
|
||||||
|
return block_idx # invalid lora name
|
||||||
|
|
||||||
|
if g[0] == "down":
|
||||||
|
block_idx = 1 + idx # 1-based index, down block index
|
||||||
|
elif g[0] == "up":
|
||||||
|
block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index
|
||||||
|
|
||||||
|
elif "mid_block_" in lora_name:
|
||||||
|
block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block
|
||||||
|
else:
|
||||||
|
if lora_name.startswith("lora_unet_"):
|
||||||
|
name = lora_name[len("lora_unet_") :]
|
||||||
|
if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts
|
||||||
|
block_idx = 1
|
||||||
|
elif name.startswith("input_blocks_"): # 1-8 to 2-9
|
||||||
|
block_idx = 1 + int(name.split("_")[2])
|
||||||
|
elif name.startswith("middle_block_"): # 10
|
||||||
|
block_idx = 10
|
||||||
|
elif name.startswith("output_blocks_"): # 0-8 to 11-19
|
||||||
|
block_idx = 11 + int(name.split("_")[2])
|
||||||
|
elif name.startswith("out_"): # 20, No LoRA in sd-scripts
|
||||||
|
block_idx = 20
|
||||||
|
|
||||||
|
return block_idx
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||||
@@ -42,12 +227,46 @@ def save_to_file(file_name, state_dict, dtype, metadata):
|
|||||||
torch.save(state_dict, file_name)
|
torch.save(state_dict, file_name)
|
||||||
|
|
||||||
|
|
||||||
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
def format_lbws(lbws):
|
||||||
|
try:
|
||||||
|
# lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している
|
||||||
|
lbws = [json.loads(lbw) for lbw in lbws]
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください")
|
||||||
|
assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください"
|
||||||
|
assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
|
||||||
|
assert all(
|
||||||
|
len(lbw) in ACCEPTABLE for lbw in lbws
|
||||||
|
), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
|
||||||
|
assert all(
|
||||||
|
all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws
|
||||||
|
), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください"
|
||||||
|
|
||||||
|
layer_num = len(lbws[0])
|
||||||
|
is_sdxl = True if layer_num in SDXL_LAYER_NUM else False
|
||||||
|
FLAGS = {
|
||||||
|
"12": LAYER12.values(),
|
||||||
|
"17": LAYER17.values(),
|
||||||
|
"20": LAYER20.values(),
|
||||||
|
"26": LAYER26.values(),
|
||||||
|
}[str(layer_num)]
|
||||||
|
LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
|
||||||
|
return lbws, is_sdxl, LBW_TARGET_IDX
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
|
||||||
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||||
merged_sd = {}
|
merged_sd = {}
|
||||||
v2 = None
|
v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2
|
||||||
base_model = None
|
base_model = None
|
||||||
for model, ratio in zip(models, ratios):
|
|
||||||
|
if lbws:
|
||||||
|
lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws)
|
||||||
|
else:
|
||||||
|
is_sdxl = False
|
||||||
|
LBW_TARGET_IDX = []
|
||||||
|
|
||||||
|
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
|
||||||
logger.info(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
@@ -57,6 +276,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
if base_model is None:
|
if base_model is None:
|
||||||
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
||||||
|
|
||||||
|
if lbw:
|
||||||
|
lbw_weights = [1] * 26
|
||||||
|
for index, value in zip(LBW_TARGET_IDX, lbw):
|
||||||
|
lbw_weights[index] = value
|
||||||
|
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
|
||||||
|
|
||||||
# merge
|
# merge
|
||||||
logger.info(f"merging...")
|
logger.info(f"merging...")
|
||||||
for key in tqdm(list(lora_sd.keys())):
|
for key in tqdm(list(lora_sd.keys())):
|
||||||
@@ -93,6 +318,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
# W <- W + U * D
|
# W <- W + U * D
|
||||||
scale = alpha / network_dim
|
scale = alpha / network_dim
|
||||||
|
|
||||||
|
if lbw:
|
||||||
|
index = get_lbw_block_index(key, is_sdxl)
|
||||||
|
is_lbw_target = index in LBW_TARGET_IDX
|
||||||
|
if is_lbw_target:
|
||||||
|
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
|
||||||
|
|
||||||
if device: # and isinstance(scale, torch.Tensor):
|
if device: # and isinstance(scale, torch.Tensor):
|
||||||
scale = scale.to(device)
|
scale = scale.to(device)
|
||||||
|
|
||||||
@@ -169,7 +400,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
|
|
||||||
|
|
||||||
def merge(args):
|
def merge(args):
|
||||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
assert len(args.models) == len(
|
||||||
|
args.ratios
|
||||||
|
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||||
|
if args.lbws:
|
||||||
|
assert len(args.models) == len(
|
||||||
|
args.lbws
|
||||||
|
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
|
||||||
|
else:
|
||||||
|
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
|
||||||
|
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == "float":
|
if p == "float":
|
||||||
@@ -187,7 +426,7 @@ def merge(args):
|
|||||||
|
|
||||||
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
||||||
state_dict, metadata, v2, base_model = merge_lora_models(
|
state_dict, metadata, v2, base_model = merge_lora_models(
|
||||||
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
|
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"calculating hashes and creating metadata...")
|
logger.info(f"calculating hashes and creating metadata...")
|
||||||
@@ -231,12 +470,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
"--save_to",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
"--models",
|
||||||
|
type=str,
|
||||||
|
nargs="*",
|
||||||
|
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||||
|
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
|
||||||
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--new_conv_rank",
|
"--new_conv_rank",
|
||||||
@@ -244,7 +490,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=None,
|
default=None,
|
||||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
|
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument(
|
||||||
|
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_metadata",
|
"--no_metadata",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user