mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix formatting in resize_lora.py
This commit is contained in:
@@ -141,17 +141,17 @@ def merge_linear(lora_down, lora_up, device):
|
|||||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||||
param_dict = {}
|
param_dict = {}
|
||||||
|
|
||||||
if dynamic_method=="sv_ratio":
|
if dynamic_method == "sv_ratio":
|
||||||
# Calculate new dim and alpha based off ratio
|
# Calculate new dim and alpha based off ratio
|
||||||
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
elif dynamic_method=="sv_cumulative":
|
elif dynamic_method == "sv_cumulative":
|
||||||
# Calculate new dim and alpha based off cumulative sum
|
# Calculate new dim and alpha based off cumulative sum
|
||||||
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
elif dynamic_method=="sv_fro":
|
elif dynamic_method == "sv_fro":
|
||||||
# Calculate new dim and alpha based off sqrt sum of squares
|
# Calculate new dim and alpha based off sqrt sum of squares
|
||||||
new_rank = index_sv_fro(S, dynamic_param) + 1
|
new_rank = index_sv_fro(S, dynamic_param) + 1
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
@@ -159,7 +159,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
|||||||
new_rank = rank
|
new_rank = rank
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
|
||||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||||
new_rank = 1
|
new_rank = 1
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
@@ -167,7 +166,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
|||||||
new_rank = rank
|
new_rank = rank
|
||||||
new_alpha = float(scale*new_rank)
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
|
||||||
# Calculate resize info
|
# Calculate resize info
|
||||||
s_sum = torch.sum(torch.abs(S))
|
s_sum = torch.sum(torch.abs(S))
|
||||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||||
@@ -254,13 +252,13 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
if not np.isnan(fro_retained):
|
if not np.isnan(fro_retained):
|
||||||
fro_list.append(float(fro_retained))
|
fro_list.append(float(fro_retained))
|
||||||
|
|
||||||
verbose_str+=f"{block_down_name:75} | "
|
verbose_str += f"{block_down_name:75} | "
|
||||||
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
verbose_str += f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||||
|
|
||||||
if verbose and dynamic_method:
|
if verbose and dynamic_method:
|
||||||
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||||
else:
|
else:
|
||||||
verbose_str+=f"\n"
|
verbose_str += "\n"
|
||||||
|
|
||||||
new_alpha = param_dict['new_alpha']
|
new_alpha = param_dict['new_alpha']
|
||||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||||
@@ -283,10 +281,15 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
|
|
||||||
|
|
||||||
def resize(args):
|
def resize(args):
|
||||||
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
|
if (
|
||||||
|
args.save_to is None or
|
||||||
|
not (args.save_to.endswith('.ckpt') or
|
||||||
|
args.save_to.endswith('.pt') or
|
||||||
|
args.save_to.endswith('.pth') or
|
||||||
|
args.save_to.endswith('.safetensors'))
|
||||||
|
):
|
||||||
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||||
|
|
||||||
|
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == 'float':
|
if p == 'float':
|
||||||
return torch.float
|
return torch.float
|
||||||
|
|||||||
Reference in New Issue
Block a user