diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index 2f586a8a..2a44592b 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -85,13 +85,13 @@ def index_sv_ratio(S, target):
 
 
 # Modified from Kohaku-blueleaf's extract/merge functions
-def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
     out_size, in_size, kernel_size, _ = weight.size()
     weight = weight.reshape(out_size, -1)
 
     _in_size = in_size * kernel_size * kernel_size
-    if out_size > 2048 and _in_size > 2048:
-        U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size))
+    if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
+        U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
         Vh = V.T
     else:
         U, S, Vh = torch.linalg.svd(weight.to(device))
@@ -110,11 +110,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
 
     return param_dict
 
 
-def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
+def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
     out_size, in_size = weight.size()
 
-    if out_size > 2048 and in_size > 2048:
-        U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size))
+    if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
+        U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
         Vh = V.T
     else:
         U, S, Vh = torch.linalg.svd(weight.to(device))
@@ -209,7 +209,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
     return param_dict
 
 
-def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
+def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
     max_old_rank = None
     new_alpha = None
     verbose_str = "\n"
@@ -273,10 +273,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
 
         if conv2d:
             full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
-            param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
+            param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
         else:
             full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
-            param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
+            param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
 
         if verbose:
             max_ratio = param_dict["max_ratio"]
@@ -347,7 +347,7 @@ def resize(args):
 
     logger.info("Resizing Lora...")
     state_dict, old_dim, new_alpha = resize_lora_model(
-        lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
+        lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
     )
 
     # update metadata
@@ -425,6 +425,13 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
     )
     parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
+    parser.add_argument(
+        "--svd_lowrank_niter",
+        type=int,
+        default=2,
+        help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
+        " / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
+    )
 
     return parser
 