From d2da3c42363a65e2a0eca79c7d4f2c439a398157 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Feb 2023 22:54:35 +0900 Subject: [PATCH 1/5] support for models with different alphas --- networks/merge_lora.py | 79 +++++++++++----- networks/merge_lora_old.py | 179 +++++++++++++++++++++++++++++++++++++ 2 files changed, 235 insertions(+), 23 deletions(-) create mode 100644 networks/merge_lora_old.py diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 1d4cb3b5..09aea7b2 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -1,5 +1,5 @@ - +import math import argparse import os import torch @@ -85,43 +85,76 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): weight = weight + ratio * (up_weight @ down_weight) * scale else: # conv2d - weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) * scale module.weight = torch.nn.Parameter(weight) def merge_lora_models(models, ratios, merge_dtype): - merged_sd = {} + base_alphas = {} # alpha for merged model + base_dims = {} - alpha = None - dim = None + merged_sd = {} for model, ratio in zip(models, ratios): print(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if 'alpha' in key: + lora_module_name = key[:key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[:key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge print(f"merging...") for key in lora_sd.keys(): if 'alpha' in key: - if key in merged_sd: - assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" - else: - alpha = lora_sd[key].detach().numpy() - merged_sd[key] = lora_sd[key] + continue + + lora_module_name = key[:key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + + if key in merged_sd: + assert merged_sd[key].size() == lora_sd[key].size( + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: - if key in merged_sd: - assert merged_sd[key].size() == lora_sd[key].size( - ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio - else: - if "lora_down" in key: - dim = lora_sd[key].size()[0] - merged_sd[key] = lora_sd[key] * ratio + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) - print(f"dim (rank): {dim}, alpha: {alpha}") - if alpha is None: - alpha = dim + print("merged model") + print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") - return merged_sd, dim, alpha + return merged_sd def merge(args): @@ -152,7 +185,7 @@ def merge(args): model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: - state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) + state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) print(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py new file mode 100644 index 00000000..1d4cb3b5 --- /dev/null +++ b/networks/merge_lora_old.py @@ -0,0 +1,179 @@ + + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +import library.model_util as model_util +import lora + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == '.safetensors': + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location='cpu') + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): + text_encoder.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[:key.index("lora_down")] + 'alpha' + + # find original module for this lora + module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + else: + # conv2d + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype): + merged_sd = {} + + alpha = None + dim = None + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if 'alpha' in key: + if key in merged_sd: + assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" + else: + alpha = lora_sd[key].detach().numpy() + merged_sd[key] = lora_sd[key] + else: + if key in merged_sd: + assert merged_sd[key].size() == lora_sd[key].size( + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio + else: + if "lora_down" in key: + dim = lora_sd[key].size()[0] + merged_sd[key] = lora_sd[key] * ratio + + print(f"dim (rank): {dim}, alpha: {alpha}") + if alpha is None: + alpha = dim + + return merged_sd, dim, alpha + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") + + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + + print(f"saving SD model to: {args.save_to}") + model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, + args.sd_model, 0, 0, save_dtype, vae) + else: + state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") + parser.add_argument("--precision", type=str, default="float", + choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") + parser.add_argument("--sd_model", type=str, default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--models", type=str, nargs='*', + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") + parser.add_argument("--ratios", type=float, nargs='*', + help="ratios for each model / それぞれのLoRAモデルの比率") + + args = parser.parse_args() + merge(args) From c7406d6b27b20dfabdeef21f9df97adb2c3bf95f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Feb 2023 22:55:00 +0900 Subject: [PATCH 2/5] keep metadata when resizing --- networks/resize_lora.py | 211 ++++++++++++++++++++++------------------ 1 file changed, 116 insertions(+), 95 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index e10d35bc..7beeb25e 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -5,148 +5,169 @@ import argparse import os import torch -from safetensors.torch import load_file, save_file +from safetensors.torch import load_file, save_file, safe_open from tqdm import tqdm +from library import train_util, model_util + def load_state_dict(file_name, dtype): - if os.path.splitext(file_name)[1] == '.safetensors': + if model_util.is_safetensors(file_name): sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() else: sd = torch.load(file_name, map_location='cpu') + metadata = None + for key in list(sd.keys()): if type(sd[key]) == torch.Tensor: sd[key] = sd[key].to(dtype) - return sd + + return sd, metadata -def save_to_file(file_name, model, state_dict, dtype): +def save_to_file(file_name, model, state_dict, dtype, metadata): if dtype is not None: for key in list(state_dict.keys()): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) - if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) + if model_util.is_safetensors(file_name): + save_file(model, file_name, metadata) else: torch.save(model, file_name) - -def resize_lora_model(model, new_rank, merge_dtype, save_dtype): - print("Loading Model...") - lora_sd = load_state_dict(model, merge_dtype) +def resize_lora_model(lora_sd, new_rank, save_dtype, device): + network_alpha = None + network_dim = None - network_alpha = None - network_dim = None + CLAMP_QUANTILE = 0.99 - CLAMP_QUANTILE = 0.99 + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and 'alpha' in key: + network_alpha = value + if network_dim is None and 'lora_down' in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim - # Extract loaded lora dim and alpha - for key, value in lora_sd.items(): - if network_alpha is None and 'alpha' in key: - network_alpha = value - if network_dim is None and 'lora_down' in key and len(value.size()) == 2: - network_dim = value.size()[0] - if network_alpha is not None and network_dim is not None: - break - if network_alpha is None: - network_alpha = network_dim + scale = network_alpha/network_dim + new_alpha = float(scale*new_rank) # calculate new alpha from scale - scale = network_alpha/network_dim - new_alpha = float(scale*new_rank) # calculate new alpha from scale + print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}") - print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}") + lora_down_weight = None + lora_up_weight = None - lora_down_weight = None - lora_up_weight = None + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None - o_lora_sd = lora_sd.copy() - block_down_name = None - block_up_name = None + print("resizing lora...") + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + if 'lora_down' in key: + block_down_name = key.split(".")[0] + lora_down_weight = value + if 'lora_up' in key: + block_up_name = key.split(".")[0] + lora_up_weight = value - print("resizing lora...") - with torch.no_grad(): - for key, value in tqdm(lora_sd.items()): - if 'lora_down' in key: - block_down_name = key.split(".")[0] - lora_down_weight = value - if 'lora_up' in key: - block_up_name = key.split(".")[0] - lora_up_weight = value + weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) - weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) + if (block_down_name == block_up_name) and weights_loaded: - if (block_down_name == block_up_name) and weights_loaded: + conv2d = (len(lora_down_weight.size()) == 4) - conv2d = (len(lora_down_weight.size()) == 4) - - if conv2d: - lora_down_weight = lora_down_weight.squeeze() - lora_up_weight = lora_up_weight.squeeze() + if conv2d: + lora_down_weight = lora_down_weight.squeeze() + lora_up_weight = lora_up_weight.squeeze() - if args.device: - org_device = lora_up_weight.device - lora_up_weight = lora_up_weight.to(args.device) - lora_down_weight = lora_down_weight.to(args.device) + if device: + org_device = lora_up_weight.device + lora_up_weight = lora_up_weight.to(args.device) + lora_down_weight = lora_down_weight.to(args.device) - full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) + full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) - U, S, Vh = torch.linalg.svd(full_weight_matrix) + U, S, Vh = torch.linalg.svd(full_weight_matrix) - U = U[:, :new_rank] - S = S[:new_rank] - U = U @ torch.diag(S) + U = U[:, :new_rank] + S = S[:new_rank] + U = U @ torch.diag(S) - Vh = Vh[:new_rank, :] + Vh = Vh[:new_rank, :] - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) - - if conv2d: - U = U.unsqueeze(2).unsqueeze(3) - Vh = Vh.unsqueeze(2).unsqueeze(3) - - if args.device: - U = U.to(org_device) - Vh = Vh.to(org_device) + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) - o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype) + if conv2d: + U = U.unsqueeze(2).unsqueeze(3) + Vh = Vh.unsqueeze(2).unsqueeze(3) - block_down_name = None - block_up_name = None - lora_down_weight = None - lora_up_weight = None - weights_loaded = False + if args.device: + U = U.to(org_device) + Vh = Vh.to(org_device) + + o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + + print("resizing complete") + return o_lora_sd, network_dim, new_alpha - print("resizing complete") - return o_lora_sd def resize(args): - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None - merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype + merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype - state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype) + print("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype) + print("resizing rank...") + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device) + + # update metadata + if metadata is None: + metadata = {} + + comment = metadata.get("ss_training_comment", "") + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) if __name__ == '__main__': From e5cc64a563055e4a61c34247b846ecf3abc01c90 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Feb 2023 22:55:21 +0900 Subject: [PATCH 3/5] support multibyte characters for filename --- tools/resize_images_to_resolution.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 0876a4d3..c98cc889 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -4,6 +4,8 @@ import cv2 import argparse import shutil import math +from PIL import Image +import numpy as np def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): @@ -35,7 +37,11 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi continue # Load image - img = cv2.imread(os.path.join(src_img_folder, filename)) + # img = cv2.imread(os.path.join(src_img_folder, filename)) + image = Image.open(os.path.join(src_img_folder, filename)) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) base, _ = os.path.splitext(filename) for max_resolution in max_resolutions: @@ -72,7 +78,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') # Save resized image in dst_img_folder - cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) + # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) + image = Image.fromarray(img) + image.save(os.path.join(dst_img_folder, new_filename), quality=100) + proc = "Resized" if current_pixels > max_pixels else "Saved" print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") From 8d86f581743b3ea183ed0e6c0dd35d434d90774a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Feb 2023 22:55:33 +0900 Subject: [PATCH 4/5] add merge script with svd --- networks/svd_merge_lora.py | 164 +++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 networks/svd_merge_lora.py diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py new file mode 100644 index 00000000..c0448fcb --- /dev/null +++ b/networks/svd_merge_lora.py @@ -0,0 +1,164 @@ + +import math +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import library.model_util as model_util +import lora + + +CLAMP_QUANTILE = 0.99 + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == '.safetensors': + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location='cpu') + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def merge_lora_models(models, ratios, new_rank, device, merge_dtype): + merged_sd = {} + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + # merge + print(f"merging...") + for key in tqdm(list(lora_sd.keys())): + if 'lora_down' not in key: + continue + + lora_module_name = key[:key.rfind(".lora_down")] + + down_weight = lora_sd[key] + network_dim = down_weight.size()[0] + + up_weight = lora_sd[lora_module_name + '.lora_up.weight'] + alpha = lora_sd.get(lora_module_name + '.alpha', network_dim) + + in_dim = down_weight.size()[1] + out_dim = up_weight.size()[0] + conv2d = len(down_weight.size()) == 4 + print(lora_module_name, network_dim, alpha, in_dim, out_dim) + + # make original weight if not exist + if lora_module_name not in merged_sd: + weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype) + if device: + weight = weight.to(device) + else: + weight = merged_sd[lora_module_name] + + # merge to weight + if device: + up_weight = up_weight.to(device) + down_weight = down_weight.to(device) + + # W <- W + U * D + scale = (alpha / network_dim) + if not conv2d: # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + else: + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) * scale + + merged_sd[lora_module_name] = weight + + # extract from merged weights + print("extract new lora...") + merged_lora_sd = {} + with torch.no_grad(): + for lora_module_name, mat in tqdm(list(merged_sd.items())): + conv2d = (len(mat.size()) == 4) + if conv2d: + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :new_rank] + S = S[:new_rank] + U = U @ torch.diag(S) + + Vh = Vh[:new_rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + up_weight = U + down_weight = Vh + + if conv2d: + up_weight = up_weight.unsqueeze(2).unsqueeze(3) + down_weight = down_weight.unsqueeze(2).unsqueeze(3) + + merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous() + merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous() + merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank) + + return merged_lora_sd + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") + parser.add_argument("--precision", type=str, default="float", + choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--models", type=str, nargs='*', + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") + parser.add_argument("--ratios", type=float, nargs='*', + help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--new_rank", type=int, default=4, + help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + + args = parser.parse_args() + merge(args) From 22e3aca89c95424261763c1affed367b24d5f807 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Fri, 10 Feb 2023 23:07:53 +0900 Subject: [PATCH 5/5] Update README.md --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index d3dd6129..e3aeb7d7 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,18 @@ Note: The LoRA models for SD 2.x is not supported too in Web UI. - 10 Feb. 2023, 2023/2/10: - Updated ``requirements.txt`` to prevent upgrading with pip taking a long time or failure to upgrade. + - ``resize_lora.py`` keeps the metadata of the model. ``dimension is resized from ...`` is added to the top of ``ss_training_comment``. + - ``merge_lora.py`` supports models with different ``alpha``s. If there is a problem, old version is ``merge_lora_old.py``. + - ``svd_merge_lora.py`` is added. This script merges LoRA models with any rank (dim) and alpha, and approximate a new LoRA with svd for a specified rank (dim). + - Note: merging scripts erase the metadata currently. + - ``resize_images_to_resolution.py`` supports multibyte characters in filenames. - pipでの更新が長時間掛かったり、更新に失敗したりするのを防ぐため、``requirements.txt``を更新しました。 + - ``resize_lora.py``がメタデータを保持するようになりました。 ``dimension is resized from ...`` という文字列が ``ss_training_comment`` の先頭に追加されます。 + - ``merge_lora.py``がalphaが異なるモデルをサポートしました。 何か問題がありましたら旧バージョン ``merge_lora_old.py`` をお使いください。 + - ``svd_merge_lora.py`` を追加しました。 複数の任意のdim (rank)、alphaのLoRAモデルをマージし、svdで任意dim(rank)のLoRAで近似します。 + - 注:マージ系のスクリプトは現時点ではメタデータを消去しますのでご注意ください。 + - ``resize_images_to_resolution.py``が日本語ファイル名をサポートしました。 + - 9 Feb. 2023, 2023/2/9: - Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource! - ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).