fix formatting in resize_lora.py

Author: mgz
Date:   2024-02-03 20:09:37 -06:00
Parent: cd19df49cd
Commit: bf2de5620c

@@ -141,17 +141,17 @@ def merge_linear(lora_down, lora_up, device):
 def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
     param_dict = {}
 
-    if dynamic_method=="sv_ratio":
+    if dynamic_method == "sv_ratio":
         # Calculate new dim and alpha based off ratio
         new_rank = index_sv_ratio(S, dynamic_param) + 1
         new_alpha = float(scale*new_rank)
-    elif dynamic_method=="sv_cumulative":
+    elif dynamic_method == "sv_cumulative":
         # Calculate new dim and alpha based off cumulative sum
         new_rank = index_sv_cumulative(S, dynamic_param) + 1
         new_alpha = float(scale*new_rank)
-    elif dynamic_method=="sv_fro":
+    elif dynamic_method == "sv_fro":
         # Calculate new dim and alpha based off sqrt sum of squares
         new_rank = index_sv_fro(S, dynamic_param) + 1
         new_alpha = float(scale*new_rank)
@@ -159,7 +159,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
         new_rank = rank
         new_alpha = float(scale*new_rank)
 
-
     if S[0] <= MIN_SV: # Zero matrix, set dim to 1
         new_rank = 1
         new_alpha = float(scale*new_rank)
@@ -167,7 +166,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
         new_rank = rank
         new_alpha = float(scale*new_rank)
 
-
     # Calculate resize info
     s_sum = torch.sum(torch.abs(S))
     s_rank = torch.sum(torch.abs(S[:new_rank]))
@@ -254,13 +252,13 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
                 if not np.isnan(fro_retained):
                     fro_list.append(float(fro_retained))
 
-                verbose_str+=f"{block_down_name:75} | "
-                verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
+                verbose_str += f"{block_down_name:75} | "
+                verbose_str += f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
 
             if verbose and dynamic_method:
-                verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
+                verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
             else:
-                verbose_str+=f"\n"
+                verbose_str += "\n"
 
             new_alpha = param_dict['new_alpha']
             o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
@@ -283,10 +281,15 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
 def resize(args):
-    if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
+    if (
+        args.save_to is None or
+        not (args.save_to.endswith('.ckpt') or
+            args.save_to.endswith('.pt') or
+            args.save_to.endswith('.pth') or
+            args.save_to.endswith('.safetensors'))
+    ):
         raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
 
     def str_to_dtype(p):
         if p == 'float':
             return torch.float
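
For reference, the dynamic_method branches in rank_resize derive the new dim from the singular values S of the merged LoRA weight. Below is a minimal sketch of how the index_sv_cumulative, index_sv_fro, and index_sv_ratio helpers referenced in the diff could work; the function bodies, the MIN_SV value, and the example spectrum are assumptions for illustration, not the code of this commit.

# Illustrative sketch only -- the real index_sv_* helpers are defined elsewhere
# in resize_lora.py; these bodies are assumptions, not the file's code.
import torch

MIN_SV = 1e-6  # assumed floor below which singular values count as zero


def index_sv_cumulative(S, target):
    # index of the first singular value at which the normalized cumulative
    # sum of S reaches `target` (e.g. 0.95 keeps ~95% of sum(S))
    cumulative = torch.cumsum(S, dim=0) / torch.sum(S)
    return int(torch.sum(cumulative < target).item())


def index_sv_fro(S, target):
    # same idea on squared singular values, so `target` is the fraction of
    # the squared Frobenius norm to retain
    sq = S.pow(2)
    cumulative = torch.cumsum(sq, dim=0) / torch.sum(sq)
    return int(torch.sum(cumulative < target).item())


def index_sv_ratio(S, max_ratio):
    # last index whose singular value is larger than S[0] / max_ratio,
    # i.e. keep entries that stay within `max_ratio` of the largest value
    min_sv = S[0] / max_ratio
    return int(torch.sum(S > min_sv).item()) - 1


if __name__ == "__main__":
    # hypothetical spectrum; rank_resize would then do new_rank = index + 1
    S = torch.tensor([4.0, 2.0, 1.0, 0.5, 0.1])
    new_rank = index_sv_cumulative(S, 0.90) + 1   # -> 3
    new_alpha = float(1.0 * new_rank)             # scale = 1
    print(new_rank, new_alpha)

In rank_resize, dynamic_param plays the role of target / max_ratio above, and the resulting new_rank feeds new_alpha = float(scale*new_rank) exactly as shown in the diff.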