mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Merge pull request #2296 from kohya-ss/add-svd-lowrank-niter
Add --svd_lowrank_niter option to resize_lora.py
This commit is contained in:
@@ -85,13 +85,13 @@ def index_sv_ratio(S, target):
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
weight = weight.reshape(out_size, -1)
|
||||
_in_size = in_size * kernel_size * kernel_size
|
||||
|
||||
if out_size > 2048 and _in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size))
|
||||
if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
@@ -110,11 +110,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
|
||||
return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
if out_size > 2048 and in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size))
|
||||
if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
@@ -209,7 +209,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
return param_dict
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
|
||||
max_old_rank = None
|
||||
new_alpha = None
|
||||
verbose_str = "\n"
|
||||
@@ -273,10 +273,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||
|
||||
if verbose:
|
||||
max_ratio = param_dict["max_ratio"]
|
||||
@@ -347,7 +347,7 @@ def resize(args):
|
||||
|
||||
logger.info("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
|
||||
)
|
||||
|
||||
# update metadata
|
||||
@@ -425,6 +425,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
|
||||
)
|
||||
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
|
||||
parser.add_argument(
|
||||
"--svd_lowrank_niter",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
|
||||
" / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user