mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add support to extract lora with resnet and 2d blocks
Modified resize script so support different types of LoRA networks (refer to Kohaku-Blueleaf module implementation for structure).
This commit is contained in:
@@ -59,6 +59,72 @@ def index_sv_fro(S, target):
|
||||
return index
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
|
||||
del U, S, Vh, weight
|
||||
return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
||||
del U, S, Vh, weight
|
||||
return param_dict
|
||||
|
||||
|
||||
def merge_conv(lora_down, lora_up, device):
|
||||
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
||||
out_size, out_rank, _, _ = lora_up.shape
|
||||
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
|
||||
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
def merge_linear(lora_down, lora_up, device):
|
||||
in_rank, in_size = lora_down.shape
|
||||
out_size, out_rank = lora_up.shape
|
||||
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
weight = lora_up @ lora_down
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
@@ -147,20 +213,11 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
||||
conv2d = (len(lora_down_weight.size()) == 4)
|
||||
|
||||
if conv2d:
|
||||
lora_down_weight = lora_down_weight.squeeze()
|
||||
lora_up_weight = lora_up_weight.squeeze()
|
||||
|
||||
if device:
|
||||
org_device = lora_up_weight.device
|
||||
lora_up_weight = lora_up_weight.to(args.device)
|
||||
lora_down_weight = lora_down_weight.to(args.device)
|
||||
|
||||
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
|
||||
|
||||
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||
|
||||
|
||||
param_dict = rank_resize(S, new_rank, dynamic_method, dynamic_param, scale)
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
|
||||
new_rank = param_dict['new_rank']
|
||||
new_alpha = param_dict['new_alpha']
|
||||
@@ -181,28 +238,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
||||
else:
|
||||
verbose_str+=f"\n"
|
||||
|
||||
U = U[:, :new_rank]
|
||||
S = S[:new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
Vh = Vh[:new_rank, :]
|
||||
|
||||
# dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
# hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
# low_val = -hi_val
|
||||
# U = U.clamp(low_val, hi_val)
|
||||
# Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.unsqueeze(2).unsqueeze(3)
|
||||
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
if device:
|
||||
U = U.to(org_device)
|
||||
Vh = Vh.to(org_device)
|
||||
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
|
||||
|
||||
block_down_name = None
|
||||
@@ -210,6 +248,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
weights_loaded = False
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
|
||||
Reference in New Issue
Block a user