diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index de405613..eb745333 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -59,14 +59,55 @@ def index_sv_fro(S, target):
     return index
 
 
+def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
+    param_dict = {}
+
+    if dynamic_method=="sv_ratio":
+        # Calculate new dim and alpha based off ratio
+        max_sv = S[0]
+        min_sv = max_sv/dynamic_param
+        new_rank = max(torch.sum(S > min_sv).item(),1)
+        new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_cumulative":
+        # Calculate new dim and alpha based off cumulative sum
+        new_rank = index_sv_cumulative(S, dynamic_param)
+        new_rank = max(new_rank, 1)
+        new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_fro":
+        # Calculate new dim and alpha based off sqrt sum of squares
+        new_rank = index_sv_fro(S, dynamic_param)
+        new_rank = min(max(new_rank, 1), len(S)-1)
+        new_alpha = float(scale*new_rank)
+    else:
+        new_rank = rank
+        new_alpha = float(scale*new_rank)
+
+    # Calculate resize info
+    s_sum = torch.sum(torch.abs(S))
+    s_rank = torch.sum(torch.abs(S[:new_rank]))
+
+    S_squared = S.pow(2)
+    s_fro = torch.sqrt(torch.sum(S_squared))
+    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
+    fro_percent = float(s_red_fro/s_fro)
+
+    param_dict["new_rank"] = new_rank
+    param_dict["new_alpha"] = new_alpha
+    param_dict["sum_retained"] = (s_rank)/s_sum
+    param_dict["fro_retained"] = fro_percent
+    param_dict["max_ratio"] = S[0]/S[new_rank]
+
+    return param_dict
+
+
 def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
     network_alpha = None
     network_dim = None
     verbose_str = "\n"
     fro_list = []
 
-    CLAMP_QUANTILE = 0.99
-
     # Extract loaded lora dim and alpha
     for key, value in lora_sd.items():
         if network_alpha is None and 'alpha' in key:
@@ -82,9 +123,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
 
     if dynamic_method:
         print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}")
-    else:
-        new_alpha = float(scale*new_rank) # calculate new alpha from scale
-        print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}")
 
     lora_down_weight = None
     lora_up_weight = None
@@ -93,7 +131,6 @@
     block_down_name = None
     block_up_name = None
 
-    print("resizing lora...")
     with torch.no_grad():
         for key, value in tqdm(lora_sd.items()):
             if 'lora_down' in key:
@@ -122,39 +159,21 @@
 
                 U, S, Vh = torch.linalg.svd(full_weight_matrix)
 
-                if dynamic_method=="sv_ratio":
-                    # Calculate new dim and alpha based off ratio
-                    max_sv = S[0]
-                    min_sv = max_sv/dynamic_param
-                    new_rank = torch.sum(S > min_sv).item()
-                    new_rank = max(new_rank, 1)
-                    new_alpha = float(scale*new_rank)
-
-                elif dynamic_method=="sv_cumulative":
-                    # Calculate new dim and alpha based off cumulative sum
-                    new_rank = index_sv_cumulative(S, dynamic_param)
-                    new_rank = max(new_rank, 1)
-                    new_alpha = float(scale*new_rank)
+                param_dict = rank_resize(S, new_rank, dynamic_method, dynamic_param, scale)
+
+                new_rank = param_dict['new_rank']
+                new_alpha = param_dict['new_alpha']
 
-                elif dynamic_method=="sv_fro":
-                    # Calculate new dim and alpha based off sqrt sum of squares
-                    new_rank = index_sv_fro(S, dynamic_param)
-                    new_rank = max(new_rank, 1)
-                    new_alpha = float(scale*new_rank)
-
                 if verbose:
-                    s_sum = torch.sum(torch.abs(S))
-                    s_rank = torch.sum(torch.abs(S[:new_rank]))
-
-                    S_squared = S.pow(2)
-                    s_fro = torch.sqrt(torch.sum(S_squared))
-                    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
-                    fro_percent = float(s_red_fro/s_fro)
-                    if not np.isnan(fro_percent):
-                        fro_list.append(float(fro_percent))
+                    max_ratio = param_dict['max_ratio']
+                    sum_retained = param_dict['sum_retained']
+                    fro_retained = param_dict['fro_retained']
+                    if not np.isnan(fro_retained):
+                        fro_list.append(float(fro_retained))
 
                     verbose_str+=f"{block_down_name:75} | "
-                    verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, fro retained: {fro_percent:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"
+                    verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
 
                 if verbose and dynamic_method:
@@ -168,12 +187,11 @@
 
                 Vh = Vh[:new_rank, :]
 
-                dist = torch.cat([U.flatten(), Vh.flatten()])
-                hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-                low_val = -hi_val
-
-                U = U.clamp(low_val, hi_val)
-                Vh = Vh.clamp(low_val, hi_val)
+                # dist = torch.cat([U.flatten(), Vh.flatten()])
+                # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+                # low_val = -hi_val
+                # U = U.clamp(low_val, hi_val)
+                # Vh = Vh.clamp(low_val, hi_val)
 
                 if conv2d:
                     U = U.unsqueeze(2).unsqueeze(3)
@@ -223,7 +241,7 @@ def resize(args):
     print("loading Model...")
     lora_sd, metadata = load_state_dict(args.model, merge_dtype)
 
-    print("resizing rank...")
+    print("Resizing Lora...")
     state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
 
     # update metadata
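
Note (reviewer sketch, not part of the patch): the extracted rank_resize picks a rank from the singular values of the merged up/down weight and reports how much of the original matrix the truncation keeps. Below is a minimal standalone illustration of the sv_fro path. index_sv_fro is reimplemented here purely for the example, since only its call site appears in this diff; the real helper is defined earlier in networks/resize_lora.py and may differ in detail.

import torch

def index_sv_fro(S, target):
    # Illustrative stand-in: smallest k such that the top-k singular
    # values carry `target` of the matrix's total Frobenius norm.
    S_squared = S.pow(2)
    cumulative = torch.sqrt(torch.cumsum(S_squared, dim=0) / torch.sum(S_squared))
    return int(torch.searchsorted(cumulative, target)) + 1

# Synthetic rank-8 update matrix, standing in for lora_up @ lora_down
# from a real checkpoint.
torch.manual_seed(0)
W = torch.randn(320, 8) @ torch.randn(8, 320)

U, S, Vh = torch.linalg.svd(W)

# Mirror the sv_fro branch of rank_resize: select the rank, clamp it
# into a valid range, then report the retained Frobenius norm.
new_rank = index_sv_fro(S, 0.9)
new_rank = min(max(new_rank, 1), len(S) - 1)

S_squared = S.pow(2)
fro_retained = float(torch.sqrt(torch.sum(S_squared[:new_rank]) / torch.sum(S_squared)))
print(f"new rank: {new_rank}, fro retained: {fro_retained:.1%}")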