Mirror of https://github.com/kohya-ss/sd-scripts.git
Add support to extract LoRA with ResNet and 2D blocks
Modified the resize script to support different types of LoRA networks (refer to Kohaku-Blueleaf's module implementation for the structure).
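In outline, the resize now round-trips every LoRA pair through a merge-then-SVD step: the low-rank factors are multiplied back into the full delta weight, decomposed, and truncated to the new rank. A minimal standalone sketch of that idea (hypothetical shapes and variable names, not part of the commit):

import torch

down = torch.randn(32, 320)        # lora_down: (rank, in_features)
up = torch.randn(320, 32)          # lora_up: (out_features, rank)

full = up @ down                   # merge the pair back into the full delta weight
U, S, Vh = torch.linalg.svd(full)  # singular value decomposition

new_rank = 16                      # target rank after resizing
new_up = U[:, :new_rank] @ torch.diag(S[:new_rank])  # (320, 16)
new_down = Vh[:new_rank, :]                          # (16, 320)
print((new_up @ new_down - full).norm())             # rank-16 reconstruction error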
@@ -59,6 +59,72 @@ def index_sv_fro(S, target):
     return index
 
 
+# Modified from Kohaku-blueleaf's extract/merge functions
+def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+  out_size, in_size, kernel_size, _ = weight.size()
+  U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
+
+  param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+  lora_rank = param_dict["new_rank"]
+
+  U = U[:, :lora_rank]
+  S = S[:lora_rank]
+  U = U @ torch.diag(S)
+  Vh = Vh[:lora_rank, :]
+
+  param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+  param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+  del U, S, Vh, weight
+  return param_dict
+
+
+def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+  out_size, in_size = weight.size()
+
+  U, S, Vh = torch.linalg.svd(weight.to(device))
+
+  param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
+  lora_rank = param_dict["new_rank"]
+
+  U = U[:, :lora_rank]
+  S = S[:lora_rank]
+  U = U @ torch.diag(S)
+  Vh = Vh[:lora_rank, :]
+
+  param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+  param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
+  del U, S, Vh, weight
+  return param_dict
+
+
+def merge_conv(lora_down, lora_up, device):
+  in_rank, in_size, kernel_size, k_ = lora_down.shape
+  out_size, out_rank, _, _ = lora_up.shape
+  assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
+
+  lora_down = lora_down.to(device)
+  lora_up = lora_up.to(device)
+
+  merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
+  weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
+  del lora_up, lora_down
+  return weight
+
+
+def merge_linear(lora_down, lora_up, device):
+  in_rank, in_size = lora_down.shape
+  out_size, out_rank = lora_up.shape
+  assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
+
+  lora_down = lora_down.to(device)
+  lora_up = lora_up.to(device)
+
+  weight = lora_up @ lora_down
+  del lora_up, lora_down
+  return weight
+
+
 def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
   param_dict = {}
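As a quick sanity check of the helpers above, a hedged round-trip sketch (assumes the diff is applied so extract_conv, merge_conv, and rank_resize are in scope, and that rank_resize falls back to the fixed rank when dynamic_method is None, as this file's else branch suggests):

import torch

weight = torch.randn(64, 32, 3, 3)                       # a Conv2d delta weight
pd = extract_conv(weight.clone(), 8, None, None, "cpu")  # fixed rank 8, no dynamic method
print(pd["lora_down"].shape, pd["lora_up"].shape)        # (8, 32, 3, 3), (64, 8, 1, 1)
approx = merge_conv(pd["lora_down"], pd["lora_up"], "cpu")
print((approx - weight).norm())                          # rank-8 reconstruction error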
@@ -147,20 +213,11 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
     conv2d = (len(lora_down_weight.size()) == 4)
 
     if conv2d:
-      lora_down_weight = lora_down_weight.squeeze()
-      lora_up_weight = lora_up_weight.squeeze()
-
-    if device:
-      org_device = lora_up_weight.device
-      lora_up_weight = lora_up_weight.to(args.device)
-      lora_down_weight = lora_down_weight.to(args.device)
-
-    full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
-
-    U, S, Vh = torch.linalg.svd(full_weight_matrix)
-
-    param_dict = rank_resize(S, new_rank, dynamic_method, dynamic_param, scale)
+      full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
+      param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+    else:
+      full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
+      param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
 
     new_rank = param_dict['new_rank']
     new_alpha = param_dict['new_alpha']
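The dispatch above keys off tensor dimensionality alone, which is what extends the script to ResNet and other 2D blocks: Linear LoRA weights are 2-D, Conv2d LoRA weights are 4-D. An illustrative check (example names only):

import torch

linear_down = torch.randn(8, 320)       # Linear lora_down: 2 dims
conv_down = torch.randn(8, 320, 3, 3)   # Conv2d lora_down: 4 dims
for w in (linear_down, conv_down):
    print(len(w.size()) == 4)           # False for Linear, True for Conv2d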
@@ -181,28 +238,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
     else:
       verbose_str+=f"\n"
 
-    U = U[:, :new_rank]
-    S = S[:new_rank]
-    U = U @ torch.diag(S)
-    Vh = Vh[:new_rank, :]
-
-    # dist = torch.cat([U.flatten(), Vh.flatten()])
-    # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-    # low_val = -hi_val
-    # U = U.clamp(low_val, hi_val)
-    # Vh = Vh.clamp(low_val, hi_val)
-
-    if conv2d:
-      U = U.unsqueeze(2).unsqueeze(3)
-      Vh = Vh.unsqueeze(2).unsqueeze(3)
-
-    if device:
-      U = U.to(org_device)
-      Vh = Vh.to(org_device)
-
-    o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
-    o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
+    o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
+    o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
     o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
 
     block_down_name = None
@@ -210,6 +248,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
     lora_down_weight = None
     lora_up_weight = None
     weights_loaded = False
+    del param_dict
 
     if verbose:
       print(verbose_str)