mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge pull request #2296 from kohya-ss/add-svd-lowrank-niter
Add --svd_lowrank_niter option to resize_lora.py
This commit is contained in:
@@ -85,13 +85,13 @@ def index_sv_ratio(S, target):
|
|||||||
|
|
||||||
|
|
||||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||||
out_size, in_size, kernel_size, _ = weight.size()
|
out_size, in_size, kernel_size, _ = weight.size()
|
||||||
weight = weight.reshape(out_size, -1)
|
weight = weight.reshape(out_size, -1)
|
||||||
_in_size = in_size * kernel_size * kernel_size
|
_in_size = in_size * kernel_size * kernel_size
|
||||||
|
|
||||||
if out_size > 2048 and _in_size > 2048:
|
if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
|
||||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size))
|
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
|
||||||
Vh = V.T
|
Vh = V.T
|
||||||
else:
|
else:
|
||||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||||
@@ -110,11 +110,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
|
|||||||
return param_dict
|
return param_dict
|
||||||
|
|
||||||
|
|
||||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
|
||||||
out_size, in_size = weight.size()
|
out_size, in_size = weight.size()
|
||||||
|
|
||||||
if out_size > 2048 and in_size > 2048:
|
if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
|
||||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size))
|
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
|
||||||
Vh = V.T
|
Vh = V.T
|
||||||
else:
|
else:
|
||||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||||
@@ -209,7 +209,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
|||||||
return param_dict
|
return param_dict
|
||||||
|
|
||||||
|
|
||||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
|
||||||
max_old_rank = None
|
max_old_rank = None
|
||||||
new_alpha = None
|
new_alpha = None
|
||||||
verbose_str = "\n"
|
verbose_str = "\n"
|
||||||
@@ -273,10 +273,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
|||||||
|
|
||||||
if conv2d:
|
if conv2d:
|
||||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
|
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||||
else:
|
else:
|
||||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
max_ratio = param_dict["max_ratio"]
|
max_ratio = param_dict["max_ratio"]
|
||||||
@@ -347,7 +347,7 @@ def resize(args):
|
|||||||
|
|
||||||
logger.info("Resizing Lora...")
|
logger.info("Resizing Lora...")
|
||||||
state_dict, old_dim, new_alpha = resize_lora_model(
|
state_dict, old_dim, new_alpha = resize_lora_model(
|
||||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
|
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
|
||||||
)
|
)
|
||||||
|
|
||||||
# update metadata
|
# update metadata
|
||||||
@@ -425,6 +425,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
|
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
|
||||||
)
|
)
|
||||||
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
|
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
|
||||||
|
parser.add_argument(
|
||||||
|
"--svd_lowrank_niter",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
|
||||||
|
" / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user