From f4a0bea6dce152e2210f611f94acfdfaa72068fe Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 21:26:06 +0900 Subject: [PATCH] format by black --- networks/svd_merge_lora.py | 188 +++++++++++++++++++++++++++++-------- 1 file changed, 147 insertions(+), 41 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 6e163aec..0decd904 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -11,8 +11,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 @@ -22,38 +24,118 @@ SDXL_LAYER_NUM = [12, 20] LAYER12 = { "BASE": True, - "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, - "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": False, + "IN01": False, + "IN02": False, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": False, + "OUT07": False, + "OUT08": False, + "OUT09": False, + "OUT10": False, + "OUT11": False, } LAYER17 = { "BASE": True, - "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, - "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": False, + "IN01": True, + "IN02": True, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, + "OUT00": False, + "OUT01": False, + "OUT02": False, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, } LAYER20 = { "BASE": True, - "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, - "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": False, + "OUT10": False, + "OUT11": False, } LAYER26 = { "BASE": True, - "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, - "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": True, + "IN10": True, + "IN11": True, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, } assert len([v for v in LAYER12.values() if v]) == 12 @@ -145,6 +227,33 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) +def format_lbws(lbws): + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all( + len(lbw) in ACCEPTABLE for lbw in lbws + ), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all( + all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws + ), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + return lbws, is_sdxl, LBW_TARGET_IDX + + def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} @@ -152,25 +261,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer base_model = None if lbws: - try: - # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している - lbws = [json.loads(lbw) for lbw in lbws] - except Exception: - raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") - assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" - assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" - assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" - assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" - - layer_num = len(lbws[0]) - is_sdxl = True if layer_num in SDXL_LAYER_NUM else False - FLAGS = { - "12": LAYER12.values(), - "17": LAYER17.values(), - "20": LAYER20.values(), - "26": LAYER26.values(), - }[str(layer_num)] - LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws) + else: + is_sdxl = False + LBW_TARGET_IDX = [] for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") @@ -186,7 +280,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer lbw_weights = [1] * 26 for index, value in zip(LBW_TARGET_IDX, lbw): lbw_weights[index] = value - print(dict(zip(LAYER26.keys(), lbw_weights))) + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") # merge logger.info(f"merging...") @@ -306,9 +400,13 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" if args.lbws: - assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" else: args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく @@ -372,10 +470,16 @@ def setup_parser() -> argparse.ArgumentParser: help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") @@ -386,7 +490,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) parser.add_argument( "--no_metadata", action="store_true",