From 0e168dd1eb9eb683faff315f754cc7d1da36096a Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 29 Mar 2026 20:33:33 +0900 Subject: [PATCH] add --svd_lowrank_niter option to resize_lora.py Allow users to control the number of iterations for torch.svd_lowrank on large matrices. Default is 2 (matching PR #2240 behavior). Set to 0 to disable svd_lowrank and use full SVD instead. Co-Authored-By: Claude Opus 4.6 --- networks/resize_lora.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 5dd1132f..a616b6ac 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -85,13 +85,13 @@ def index_sv_ratio(S, target): # Modified from Kohaku-blueleaf's extract/merge functions -def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2): out_size, in_size, kernel_size, _ = weight.size() weight = weight.reshape(out_size, -1) _in_size = in_size * kernel_size * kernel_size - if out_size > 2048 and _in_size > 2048: - U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size)) + if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048: + U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter) Vh = V.T else: U, S, Vh = torch.linalg.svd(weight.to(device)) @@ -110,11 +110,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale return param_dict -def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2): out_size, in_size = weight.size() - if out_size > 2048 and in_size > 2048: - U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size)) + if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048: + U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter) Vh = V.T else: U, S, Vh = torch.linalg.svd(weight.to(device)) @@ -209,7 +209,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): return param_dict -def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): +def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2): max_old_rank = None new_alpha = None verbose_str = "\n" @@ -273,10 +273,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna if conv2d: full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) - param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale) + param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter) else: full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) - param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter) if verbose: max_ratio = param_dict["max_ratio"] @@ -347,7 +347,7 @@ def resize(args): logger.info("Resizing Lora...") state_dict, old_dim, new_alpha = resize_lora_model( - lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose + lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter ) # update metadata @@ -425,6 +425,13 @@ def setup_parser() -> argparse.ArgumentParser: help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank", ) parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction") + parser.add_argument( + "--svd_lowrank_niter", + type=int, + default=2, + help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD" + " / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用", + ) return parser