format by black

This commit is contained in:
Kohya S
2024-09-13 21:26:06 +09:00
parent 734d2e5b2b
commit f4a0bea6dc


@@ -11,8 +11,10 @@ from library import sai_model_spec, train_util
 import library.model_util as model_util
 import lora
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)

 CLAMP_QUANTILE = 0.99
@@ -22,38 +24,118 @@ SDXL_LAYER_NUM = [12, 20]
 LAYER12 = {
     "BASE": True,
-    "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True,
-    "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "IN00": False,
+    "IN01": False,
+    "IN02": False,
+    "IN03": False,
+    "IN04": True,
+    "IN05": True,
+    "IN06": False,
+    "IN07": True,
+    "IN08": True,
+    "IN09": False,
+    "IN10": False,
+    "IN11": False,
     "MID": True,
-    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
-    "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False
+    "OUT00": True,
+    "OUT01": True,
+    "OUT02": True,
+    "OUT03": True,
+    "OUT04": True,
+    "OUT05": True,
+    "OUT06": False,
+    "OUT07": False,
+    "OUT08": False,
+    "OUT09": False,
+    "OUT10": False,
+    "OUT11": False,
 }
 LAYER17 = {
     "BASE": True,
-    "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True,
-    "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "IN00": False,
+    "IN01": True,
+    "IN02": True,
+    "IN03": False,
+    "IN04": True,
+    "IN05": True,
+    "IN06": False,
+    "IN07": True,
+    "IN08": True,
+    "IN09": False,
+    "IN10": False,
+    "IN11": False,
     "MID": True,
-    "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True,
-    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True,
+    "OUT00": False,
+    "OUT01": False,
+    "OUT02": False,
+    "OUT03": True,
+    "OUT04": True,
+    "OUT05": True,
+    "OUT06": True,
+    "OUT07": True,
+    "OUT08": True,
+    "OUT09": True,
+    "OUT10": True,
+    "OUT11": True,
 }
 LAYER20 = {
     "BASE": True,
-    "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True,
-    "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
+    "IN00": True,
+    "IN01": True,
+    "IN02": True,
+    "IN03": True,
+    "IN04": True,
+    "IN05": True,
+    "IN06": True,
+    "IN07": True,
+    "IN08": True,
+    "IN09": False,
+    "IN10": False,
+    "IN11": False,
     "MID": True,
-    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
-    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False,
+    "OUT00": True,
+    "OUT01": True,
+    "OUT02": True,
+    "OUT03": True,
+    "OUT04": True,
+    "OUT05": True,
+    "OUT06": True,
+    "OUT07": True,
+    "OUT08": True,
+    "OUT09": False,
+    "OUT10": False,
+    "OUT11": False,
 }
 LAYER26 = {
     "BASE": True,
-    "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True,
-    "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True,
+    "IN00": True,
+    "IN01": True,
+    "IN02": True,
+    "IN03": True,
+    "IN04": True,
+    "IN05": True,
+    "IN06": True,
+    "IN07": True,
+    "IN08": True,
+    "IN09": True,
+    "IN10": True,
+    "IN11": True,
     "MID": True,
-    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
-    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True,
+    "OUT00": True,
+    "OUT01": True,
+    "OUT02": True,
+    "OUT03": True,
+    "OUT04": True,
+    "OUT05": True,
+    "OUT06": True,
+    "OUT07": True,
+    "OUT08": True,
+    "OUT09": True,
+    "OUT10": True,
+    "OUT11": True,
 }
 assert len([v for v in LAYER12.values() if v]) == 12
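Aside (not part of the commit): each LAYERxx preset simply marks which of the 26 named U-Net blocks receive a per-layer weight, and the asserts only check that the number of enabled flags matches the preset name. A minimal sketch, assuming the four dictionaries are defined exactly as in the diff above:

# Minimal sketch; assumes LAYER12/LAYER17/LAYER20/LAYER26 from the diff above are in scope.
# Mirrors the asserts: each preset must enable exactly as many blocks as its name says.
for name, flags, expected in [
    ("LAYER12", LAYER12, 12),
    ("LAYER17", LAYER17, 17),
    ("LAYER20", LAYER20, 20),
    ("LAYER26", LAYER26, 26),
]:
    enabled = [key for key, on in flags.items() if on]
    assert len(enabled) == expected, f"{name} enables {len(enabled)} blocks, expected {expected}"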
@@ -145,13 +227,7 @@ def save_to_file(file_name, state_dict, dtype, metadata):
     torch.save(state_dict, file_name)


-def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
-    logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
-    merged_sd = {}
-    v2 = None  # This is meaning LoRA Metadata v2, Not meaning SD2
-    base_model = None
-
-    if lbws:
+def format_lbws(lbws):
     try:
         # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している
         lbws = [json.loads(lbw) for lbw in lbws]
@@ -159,8 +235,12 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
         raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください")
     assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください"
     assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
-    assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
-    assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください"
+    assert all(
+        len(lbw) in ACCEPTABLE for lbw in lbws
+    ), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
+    assert all(
+        all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws
+    ), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください"

     layer_num = len(lbws[0])
     is_sdxl = True if layer_num in SDXL_LAYER_NUM else False
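For context (illustrative only, hypothetical values): each --lbws argument is a JSON string, so the parsing above turns it into a list of numbers whose length must be one of the ACCEPTABLE block counts; a length found in SDXL_LAYER_NUM = [12, 20] selects the SDXL presets.

import json

# Hypothetical input: one JSON string per LoRA model, here a 17-entry preset
# (17 is not in SDXL_LAYER_NUM = [12, 20], so is_sdxl would come out False).
raw = ["[1,1,1,1,1,1,1,1,0,0,0,0,0,1,1,1,1]"]
lbws = [json.loads(s) for s in raw]
assert all(isinstance(lbw, list) for lbw in lbws)
assert len(set(len(lbw) for lbw in lbws)) == 1
layer_num = len(lbws[0])  # 17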
@@ -171,6 +251,20 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
         "26": LAYER26.values(),
     }[str(layer_num)]
     LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
+    return lbws, is_sdxl, LBW_TARGET_IDX
+
+
+def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
+    logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
+    merged_sd = {}
+    v2 = None  # This is meaning LoRA Metadata v2, Not meaning SD2
+    base_model = None
+    if lbws:
+        lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws)
+    else:
+        is_sdxl = False
+        LBW_TARGET_IDX = []

     for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
         logger.info(f"loading: {model}")
@@ -186,7 +280,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
             lbw_weights = [1] * 26
             for index, value in zip(LBW_TARGET_IDX, lbw):
                 lbw_weights[index] = value
-            print(dict(zip(LAYER26.keys(), lbw_weights)))
+            logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")

         # merge
         logger.info(f"merging...")
@@ -306,9 +400,13 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
 def merge(args):
-    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
+    assert len(args.models) == len(
+        args.ratios
+    ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

     if args.lbws:
-        assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
+        assert len(args.models) == len(
+            args.lbws
+        ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
     else:
         args.lbws = []  # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
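For illustration (hypothetical argument values, not from the commit): the two length checks fire before any model is loaded, and the empty-list fallback is what keeps the itertools.zip_longest call in merge_lora_models well-behaved when no layer weights are given.

from types import SimpleNamespace

# Hypothetical parsed arguments: one ratio and one lbw string per model.
args = SimpleNamespace(
    models=["lora_a.safetensors", "lora_b.safetensors"],
    ratios=[1.0, 0.8],
    lbws=["[1,1,1,1,1,1,1,1,1,1,1,1]", "[0,0,0,0,1,1,1,1,1,1,1,1]"],
)

assert len(args.models) == len(args.ratios)
if args.lbws:
    assert len(args.models) == len(args.lbws)
else:
    args.lbws = []  # empty list instead of None so zip_longest can iterate it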
@@ -372,10 +470,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨", help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
) )
parser.add_argument( parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" "--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
) )
parser.add_argument( parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" "--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
) )
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
@@ -386,7 +490,9 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
     )
-    parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument(
+        "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
+    )
     parser.add_argument(
         "--no_metadata",
         action="store_true",