fix formatting in resize_lora.py

This commit is contained in:
mgz
2024-02-03 20:09:37 -06:00
parent cd19df49cd
commit bf2de5620c

View File

@@ -159,7 +159,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
new_rank = rank
new_alpha = float(scale*new_rank)
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
new_alpha = float(scale*new_rank)
@@ -167,7 +166,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
new_rank = rank
new_alpha = float(scale*new_rank)
# Calculate resize info
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
@@ -260,7 +258,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
if verbose and dynamic_method:
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str+=f"\n"
verbose_str += "\n"
new_alpha = param_dict['new_alpha']
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
@@ -283,10 +281,15 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
def resize(args):
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
if (
args.save_to is None or
not (args.save_to.endswith('.ckpt') or
args.save_to.endswith('.pt') or
args.save_to.endswith('.pth') or
args.save_to.endswith('.safetensors'))
):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
def str_to_dtype(p):
if p == 'float':
return torch.float