Merge pull request #2296 from kohya-ss/add-svd-lowrank-niter

Add --svd_lowrank_niter option to resize_lora.py
This commit is contained in:
Kohya S.
2026-03-29 20:40:55 +09:00
committed by GitHub

View File

@@ -85,13 +85,13 @@ def index_sv_ratio(S, target):
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
out_size, in_size, kernel_size, _ = weight.size()
weight = weight.reshape(out_size, -1)
_in_size = in_size * kernel_size * kernel_size
if out_size > 2048 and _in_size > 2048:
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size))
if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
Vh = V.T
else:
U, S, Vh = torch.linalg.svd(weight.to(device))
@@ -110,11 +110,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
return param_dict
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
out_size, in_size = weight.size()
if out_size > 2048 and in_size > 2048:
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size))
if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
Vh = V.T
else:
U, S, Vh = torch.linalg.svd(weight.to(device))
@@ -209,7 +209,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
return param_dict
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
max_old_rank = None
new_alpha = None
verbose_str = "\n"
@@ -273,10 +273,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
if verbose:
max_ratio = param_dict["max_ratio"]
@@ -347,7 +347,7 @@ def resize(args):
logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
)
# update metadata
@@ -425,6 +425,13 @@ def setup_parser() -> argparse.ArgumentParser:
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
)
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
parser.add_argument(
"--svd_lowrank_niter",
type=int,
default=2,
help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
" / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
)
return parser