From ddfe94b33bc47b7fbdbcff88de02e46a91b06fd9 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 22 Jan 2023 21:33:35 +0900
Subject: [PATCH] Update for alpha value

---
 networks/extract_lora_from_models.py | 24 ++++++++++-------
 networks/merge_lora.py               | 40 +++++++++++++++++++++-------
 2 files changed, 45 insertions(+), 19 deletions(-)

diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py
index 0a4c3a00..84d705cf 100644
--- a/networks/extract_lora_from_models.py
+++ b/networks/extract_lora_from_models.py
@@ -44,9 +44,9 @@ def svd(args):
   print(f"loading SD model : {args.model_tuned}")
   text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
 
-  # create LoRA network to extract weights
-  lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o)
-  lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t)
+  # create LoRA network to extract weights: Use dim (rank) as alpha
+  lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
+  lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
   assert len(lora_network_o.text_encoder_loras) == len(
       lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
 
@@ -77,10 +77,10 @@ def svd(args):
       module_t = lora_t.org_module
       diff = module_t.weight - module_o.weight
       diff = diff.float()
-
+
       if args.device:
         diff = diff.to(args.device)
-
+
       diffs[lora_name] = diff
 
   # make LoRA with svd
@@ -116,6 +116,9 @@ def svd(args):
   print(f"LoRA has {len(lora_sd)} weights.")
 
   for key in list(lora_sd.keys()):
+    if "alpha" in key:
+      continue
+
     lora_name = key.split('.')[0]
     i = 0 if "lora_up" in key else 1
 
@@ -124,7 +127,7 @@
     if len(lora_sd[key].size()) == 4:
       weights = weights.unsqueeze(2).unsqueeze(3)
 
-    assert weights.size() == lora_sd[key].size()
+    assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
     lora_sd[key] = weights
 
   # load state dict to LoRA and save it
@@ -135,7 +138,10 @@
   if dir_name and not os.path.exists(dir_name):
     os.makedirs(dir_name, exist_ok=True)
 
-  lora_network_o.save_weights(args.save_to, save_dtype, {})
+  # minimum metadata
+  metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
+
+  lora_network_o.save_weights(args.save_to, save_dtype, metadata)
   print(f"LoRA weights are saved to: {args.save_to}")
 
 
@@ -151,8 +157,8 @@ if __name__ == '__main__':
                       help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
   parser.add_argument("--save_to", type=str, default=None,
                       help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
-  parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)")
-  parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う")
+  parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
+  parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
 
   args = parser.parse_args()
   svd(args)
diff --git a/networks/merge_lora.py b/networks/merge_lora.py
index d873a8ef..1d4cb3b5 100644
--- a/networks/merge_lora.py
+++ b/networks/merge_lora.py
@@ -61,6 +61,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
   for key in lora_sd.keys():
     if "lora_down" in key:
       up_key = key.replace("lora_down", "lora_up")
+      alpha_key = key[:key.index("lora_down")] + 'alpha'
 
       # find original module for this lora
       module_name = '.'.join(key.split('.')[:-2])  # remove trailing ".lora_down.weight"
@@ -73,14 +74,18 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
       down_weight = lora_sd[key]
       up_weight = lora_sd[up_key]
 
+      dim = down_weight.size()[0]
+      alpha = lora_sd.get(alpha_key, dim)
+      scale = alpha / dim
+
      # W <- W + U * D
       weight = module.weight
       if len(weight.size()) == 2:
         # linear
-        weight = weight + ratio * (up_weight @ down_weight)
+        weight = weight + ratio * (up_weight @ down_weight) * scale
       else:
         # conv2d
-        weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
 
       module.weight = torch.nn.Parameter(weight)
 
@@ -88,20 +93,35 @@ def merge_lora_models(models, ratios, merge_dtype):
   merged_sd = {}
+  alpha = None
+  dim = None
 
   for model, ratio in zip(models, ratios):
     print(f"loading: {model}")
     lora_sd = load_state_dict(model, merge_dtype)
 
     print(f"merging...")
     for key in lora_sd.keys():
-      if key in merged_sd:
-        assert merged_sd[key].size() == lora_sd[key].size(
-        ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
-        merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
+      if 'alpha' in key:
+        if key in merged_sd:
+          assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
+        else:
+          alpha = lora_sd[key].detach().numpy()
+          merged_sd[key] = lora_sd[key]
       else:
-        merged_sd[key] = lora_sd[key] * ratio
+        if key in merged_sd:
+          assert merged_sd[key].size() == lora_sd[key].size(
+          ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
+          merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
+        else:
+          if "lora_down" in key:
+            dim = lora_sd[key].size()[0]
+          merged_sd[key] = lora_sd[key] * ratio
 
-  return merged_sd
+  print(f"dim (rank): {dim}, alpha: {alpha}")
+  if alpha is None:
+    alpha = dim
+
+  return merged_sd, dim, alpha
 
 
 def merge(args):
@@ -132,7 +152,7 @@ def merge(args):
     model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
                                                 args.sd_model, 0, 0, save_dtype, vae)
   else:
-    state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
+    state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
 
     print(f"saving model to: {args.save_to}")
     save_to_file(args.save_to, state_dict, state_dict, save_dtype)
@@ -145,7 +165,7 @@ if __name__ == '__main__':
   parser.add_argument("--save_precision", type=str, default=None,
                       choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
   parser.add_argument("--precision", type=str, default="float",
-                      choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度")
+                      choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
   parser.add_argument("--sd_model", type=str, default=None,
                       help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
   parser.add_argument("--save_to", type=str, default=None,
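For reference, a minimal sketch (not part of the commit, made-up sizes and names) of the alpha/dim scaling that merge_to_sd_model now applies to a Linear weight. Because extract_lora_from_models.py writes alpha equal to dim (both in the network and in the ss_network_alpha metadata), and merge_to_sd_model falls back to alpha = dim when a state dict has no alpha key, scale stays 1.0 for such files and they merge exactly as before:

    import torch

    out_dim, in_dim, dim = 8, 16, 4          # dim is the LoRA rank
    ratio, alpha = 1.0, 4.0                  # alpha falls back to dim when no alpha key exists

    weight = torch.randn(out_dim, in_dim)    # original Linear weight W
    down_weight = torch.randn(dim, in_dim)   # lora_down.weight (D)
    up_weight = torch.randn(out_dim, dim)    # lora_up.weight (U)

    scale = alpha / dim                      # 1.0 here, so the merge matches the old behavior
    merged = weight + ratio * (up_weight @ down_weight) * scale   # W <- W + U * D * scale
    print(merged.shape)                      # torch.Size([8, 16])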