From 214ed092f2208caa5636bb631e7f37ab97c67a3f Mon Sep 17 00:00:00 2001
From: mgz-dev <49577754+mgz-dev@users.noreply.github.com>
Date: Sat, 4 Mar 2023 02:01:10 -0600
Subject: [PATCH] Add support to extract LoRA with ResNet and 2D conv blocks

Modified the resize script to support different types of LoRA networks
(refer to Kohaku-Blueleaf's module implementation for structure).
---
 networks/resize_lora.py | 109 +++++++++++++++++++++++++++-------------
 1 file changed, 74 insertions(+), 35 deletions(-)

diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index eb745333..77d79d9f 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -59,6 +59,72 @@ def index_sv_fro(S, target):
   return index
 
 
+# Modified from Kohaku-blueleaf's extract/merge functions
+def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+  out_size, in_size, kernel_size, _ = weight.size()
+  U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
+
+  param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+
+  lora_rank = param_dict["new_rank"]
+
+  U = U[:, :lora_rank]
+  S = S[:lora_rank]
+  U = U @ torch.diag(S)
+  Vh = Vh[:lora_rank, :]
+
+  param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+  param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+  del U, S, Vh, weight
+  return param_dict
+
+
+def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+  out_size, in_size = weight.size()
+
+  U, S, Vh = torch.linalg.svd(weight.to(device))
+
+  param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+  lora_rank = param_dict["new_rank"]
+
+  U = U[:, :lora_rank]
+  S = S[:lora_rank]
+  U = U @ torch.diag(S)
+  Vh = Vh[:lora_rank, :]
+
+  param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+  param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
+  del U, S, Vh, weight
+  return param_dict
+
+
+def merge_conv(lora_down, lora_up, device):
+  in_rank, in_size, kernel_size, k_ = lora_down.shape
+  out_size, out_rank, _, _ = lora_up.shape
+  assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
+
+  lora_down = lora_down.to(device)
+  lora_up = lora_up.to(device)
+
+  merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
+  weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
+  del lora_up, lora_down
+  return weight
+
+
+def merge_linear(lora_down, lora_up, device):
+  in_rank, in_size = lora_down.shape
+  out_size, out_rank = lora_up.shape
+  assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
+
+  lora_down = lora_down.to(device)
+  lora_up = lora_up.to(device)
+
+  weight = lora_up @ lora_down
+  del lora_up, lora_down
+  return weight
+
+
 def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
   param_dict = {}
 
@@ -147,20 +213,11 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
       conv2d = (len(lora_down_weight.size()) == 4)
 
       if conv2d:
-        lora_down_weight = lora_down_weight.squeeze()
-        lora_up_weight = lora_up_weight.squeeze()
-
-      if device:
-        org_device = lora_up_weight.device
-        lora_up_weight = lora_up_weight.to(args.device)
-        lora_down_weight = lora_down_weight.to(args.device)
-
-      full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
-
-      U, S, Vh = torch.linalg.svd(full_weight_matrix)
-
-
-      param_dict = rank_resize(S, new_rank, dynamic_method, dynamic_param, scale)
+        full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
+        param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+      else:
+        full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
+        param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
 
       new_rank = param_dict['new_rank']
       new_alpha = param_dict['new_alpha']
@@ -181,28 +238,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
       else:
         verbose_str+=f"\n"
 
-      U = U[:, :new_rank]
-      S = S[:new_rank]
-      U = U @ torch.diag(S)
-      Vh = Vh[:new_rank, :]
-
-      # dist = torch.cat([U.flatten(), Vh.flatten()])
-      # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-      # low_val = -hi_val
-      # U = U.clamp(low_val, hi_val)
-      # Vh = Vh.clamp(low_val, hi_val)
-
-      if conv2d:
-        U = U.unsqueeze(2).unsqueeze(3)
-        Vh = Vh.unsqueeze(2).unsqueeze(3)
-
-      if device:
-        U = U.to(org_device)
-        Vh = Vh.to(org_device)
-
-      o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
-      o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
+      o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
+      o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
       o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
 
       block_down_name = None
@@ -210,6 +248,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
       lora_down_weight = None
       lora_up_weight = None
       weights_loaded = False
+      del param_dict
 
   if verbose:
     print(verbose_str)
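
Note (illustrative, not part of the patch): the snippet below is a standalone sketch of the merge -> truncated-SVD -> re-extract round trip that merge_conv()/extract_conv() above perform for a Conv2d LoRA pair. The shapes, ranks, and random weights are assumptions for illustration only; in the script itself the new rank comes from rank_resize() and the user-supplied arguments.

import torch

rank, new_rank = 16, 4                      # original and reduced LoRA rank (illustrative)
out_size, in_size, kernel_size = 64, 32, 3  # illustrative Conv2d layer shape

# LoRA factors for a conv layer: down is (rank, in, k, k), up is (out, rank, 1, 1)
lora_down = torch.randn(rank, in_size, kernel_size, kernel_size)
lora_up = torch.randn(out_size, rank, 1, 1)

# Merge the pair into the full weight delta (same reshape trick as merge_conv)
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(rank, -1)
full_weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)

# Re-extract at the lower rank with a truncated SVD (same steps as extract_conv)
U, S, Vh = torch.linalg.svd(full_weight.reshape(out_size, -1))
U, S, Vh = U[:, :new_rank], S[:new_rank], Vh[:new_rank, :]
new_up = (U @ torch.diag(S)).reshape(out_size, new_rank, 1, 1)
new_down = Vh.reshape(new_rank, in_size, kernel_size, kernel_size)

# The Frobenius error of the rank-reduced pair is governed by the dropped singular values
approx = (new_up.reshape(out_size, -1) @ new_down.reshape(new_rank, -1)).reshape_as(full_weight)
rel_err = ((approx - full_weight).norm() / full_weight.norm()).item()
print(f"relative error at rank {new_rank}: {rel_err:.4f}")

The Linear case in the patch is the same idea without the reshapes: merge_linear() is just lora_up @ lora_down, and extract_linear() truncates the SVD of that 2-D matrix directly.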