mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix formatting in resize_lora.py
This commit is contained in:
@@ -159,7 +159,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
|||||||
new_rank = rank
|
new_rank = rank
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
|
||||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||||
new_rank = 1
|
new_rank = 1
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
@@ -167,7 +166,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
|||||||
new_rank = rank
|
new_rank = rank
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
|
||||||
# Calculate resize info
|
# Calculate resize info
|
||||||
s_sum = torch.sum(torch.abs(S))
|
s_sum = torch.sum(torch.abs(S))
|
||||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||||
@@ -260,7 +258,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
if verbose and dynamic_method:
|
if verbose and dynamic_method:
|
||||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||||
else:
|
else:
|
||||||
verbose_str+=f"\n"
|
verbose_str += "\n"
|
||||||
|
|
||||||
new_alpha = param_dict['new_alpha']
|
new_alpha = param_dict['new_alpha']
|
||||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||||
@@ -283,10 +281,15 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
|
|
||||||
|
|
||||||
def resize(args):
|
def resize(args):
|
||||||
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
|
if (
|
||||||
|
args.save_to is None or
|
||||||
|
not (args.save_to.endswith('.ckpt') or
|
||||||
|
args.save_to.endswith('.pt') or
|
||||||
|
args.save_to.endswith('.pth') or
|
||||||
|
args.save_to.endswith('.safetensors'))
|
||||||
|
):
|
||||||
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||||
|
|
||||||
|
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == 'float':
|
if p == 'float':
|
||||||
return torch.float
|
return torch.float
|
||||||
|
|||||||
Reference in New Issue
Block a user