From efe4c983410dfb02185cf3cef4851191e4380f1e Mon Sep 17 00:00:00 2001 From: mgz-dev <49577754+mgz-dev@users.noreply.github.com> Date: Tue, 28 Feb 2023 14:55:15 -0600 Subject: [PATCH] Enable ability to resize lora dim based off ratios --- networks/resize_lora.py | 44 +++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 271de8ef..c4d8a4d8 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -38,10 +38,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) -def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): +def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose): network_alpha = None network_dim = None verbose_str = "\n" + ratio_flag = False CLAMP_QUANTILE = 0.99 @@ -57,9 +58,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): network_alpha = network_dim scale = network_alpha/network_dim - new_alpha = float(scale*new_rank) # calculate new alpha from scale - - print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}") + if not sv_ratio: + new_alpha = float(scale*new_rank) # calculate new alpha from scale + print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}") + else: + print(f"Dynamically determining new alphas and dims based off sv ratio: {sv_ratio}") + ratio_flag = True lora_down_weight = None lora_up_weight = None @@ -97,11 +101,24 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose): U, S, Vh = torch.linalg.svd(full_weight_matrix) + if ratio_flag: + # Calculate new dim and alpha for dynamic sizing + max_sv = S[0] + min_sv = max_sv/sv_ratio + new_rank = torch.sum(S > min_sv).item() + new_rank = max(new_rank, 1) + new_alpha = float(scale*new_rank) + if verbose: s_sum = torch.sum(torch.abs(S)) s_rank = torch.sum(torch.abs(S[:new_rank])) - verbose_str+=f"{block_down_name:76} | " - verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n" + verbose_str+=f"{block_down_name:75} | " + verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}" + + if verbose and ratio_flag: + verbose_str+=f", dynamic| dim: {new_rank}, alpha: {new_alpha}\n" + else: + verbose_str+=f"\n" U = U[:, :new_rank] S = S[:new_rank] @@ -160,16 +177,21 @@ def resize(args): lora_sd, metadata = load_state_dict(args.model, merge_dtype) print("resizing rank...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose) + state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.sv_ratio, args.verbose) # update metadata if metadata is None: metadata = {} comment = metadata.get("ss_training_comment", "") - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) + if not args.sv_ratio: + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = f"Dynamic resize from {old_dim} with ratio {args.sv_ratio}; {comment}" + metadata["ss_network_dim"] = 'Dynamic' + metadata["ss_network_alpha"] = 'Dynamic' model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -193,6 +215,8 @@ if __name__ == '__main__': parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する") + parser.add_argument("--sv_ratio", type=float, default=None, + help="Specify svd ratio for dim calcs. Will override --new_rank") args = parser.parse_args() resize(args)