add verbosity option for resize_lora.py

add --verbose flag to print additional statistics during resize_lora function
correct some parameter references in resize_lora_model function
This commit is contained in:
michaelgzhang
2023-02-11 02:38:13 -06:00
parent b32abdd327
commit 55521eece0

View File

@@ -38,9 +38,10 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
torch.save(model, file_name) torch.save(model, file_name)
def resize_lora_model(lora_sd, new_rank, save_dtype, device): def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
network_alpha = None network_alpha = None
network_dim = None network_dim = None
verbose_str = "\n"
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 0.99
@@ -96,6 +97,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
U, S, Vh = torch.linalg.svd(full_weight_matrix) U, S, Vh = torch.linalg.svd(full_weight_matrix)
if verbose:
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
verbose_str+=f"{block_down_name:76} | "
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) to max(S_dropped) ratio: {S[0]/S[new_rank]:0.1f}\n"
U = U[:, :new_rank] U = U[:, :new_rank]
S = S[:new_rank] S = S[:new_rank]
U = U @ torch.diag(S) U = U @ torch.diag(S)
@@ -113,7 +120,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
U = U.unsqueeze(2).unsqueeze(3) U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3) Vh = Vh.unsqueeze(2).unsqueeze(3)
if args.device: if device:
U = U.to(org_device) U = U.to(org_device)
Vh = Vh.to(org_device) Vh = Vh.to(org_device)
@@ -127,6 +134,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
lora_up_weight = None lora_up_weight = None
weights_loaded = False weights_loaded = False
if verbose:
print(verbose_str)
print("resizing complete") print("resizing complete")
return o_lora_sd, network_dim, new_alpha return o_lora_sd, network_dim, new_alpha
@@ -151,7 +160,7 @@ def resize(args):
lora_sd, metadata = load_state_dict(args.model, merge_dtype) lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print("resizing rank...") print("resizing rank...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device) state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
# update metadata # update metadata
if metadata is None: if metadata is None:
@@ -182,6 +191,8 @@ if __name__ == '__main__':
parser.add_argument("--model", type=str, default=None, parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information")
args = parser.parse_args() args = parser.parse_args()
resize(args) resize(args)