mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
refactor and bug fix for too large sv_ratio
- code refactor to be able to re-use same function for dynamic extract lora - remove clamp - fix issue where if sv_ratio is too high index goes out of bounds
This commit is contained in:
@@ -59,14 +59,55 @@ def index_sv_fro(S, target):
|
|||||||
return index
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||||
|
param_dict = {}
|
||||||
|
|
||||||
|
if dynamic_method=="sv_ratio":
|
||||||
|
# Calculate new dim and alpha based off ratio
|
||||||
|
max_sv = S[0]
|
||||||
|
min_sv = max_sv/dynamic_param
|
||||||
|
new_rank = max(torch.sum(S > min_sv).item(),1)
|
||||||
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
elif dynamic_method=="sv_cumulative":
|
||||||
|
# Calculate new dim and alpha based off cumulative sum
|
||||||
|
new_rank = index_sv_cumulative(S, dynamic_param)
|
||||||
|
new_rank = max(new_rank, 1)
|
||||||
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
elif dynamic_method=="sv_fro":
|
||||||
|
# Calculate new dim and alpha based off sqrt sum of squares
|
||||||
|
new_rank = index_sv_fro(S, dynamic_param)
|
||||||
|
new_rank = min(max(new_rank, 1), len(S)-1)
|
||||||
|
new_alpha = float(scale*new_rank)
|
||||||
|
else:
|
||||||
|
new_rank = rank
|
||||||
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
|
# Calculate resize info
|
||||||
|
s_sum = torch.sum(torch.abs(S))
|
||||||
|
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||||
|
|
||||||
|
S_squared = S.pow(2)
|
||||||
|
s_fro = torch.sqrt(torch.sum(S_squared))
|
||||||
|
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
||||||
|
fro_percent = float(s_red_fro/s_fro)
|
||||||
|
|
||||||
|
param_dict["new_rank"] = new_rank
|
||||||
|
param_dict["new_alpha"] = new_alpha
|
||||||
|
param_dict["sum_retained"] = (s_rank)/s_sum
|
||||||
|
param_dict["fro_retained"] = fro_percent
|
||||||
|
param_dict["max_ratio"] = S[0]/S[new_rank]
|
||||||
|
|
||||||
|
return param_dict
|
||||||
|
|
||||||
|
|
||||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||||
network_alpha = None
|
network_alpha = None
|
||||||
network_dim = None
|
network_dim = None
|
||||||
verbose_str = "\n"
|
verbose_str = "\n"
|
||||||
fro_list = []
|
fro_list = []
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
|
||||||
|
|
||||||
# Extract loaded lora dim and alpha
|
# Extract loaded lora dim and alpha
|
||||||
for key, value in lora_sd.items():
|
for key, value in lora_sd.items():
|
||||||
if network_alpha is None and 'alpha' in key:
|
if network_alpha is None and 'alpha' in key:
|
||||||
@@ -82,9 +123,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
|
|
||||||
if dynamic_method:
|
if dynamic_method:
|
||||||
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}")
|
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}")
|
||||||
else:
|
|
||||||
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
|
||||||
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}")
|
|
||||||
|
|
||||||
lora_down_weight = None
|
lora_down_weight = None
|
||||||
lora_up_weight = None
|
lora_up_weight = None
|
||||||
@@ -93,7 +131,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
block_down_name = None
|
block_down_name = None
|
||||||
block_up_name = None
|
block_up_name = None
|
||||||
|
|
||||||
print("resizing lora...")
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for key, value in tqdm(lora_sd.items()):
|
for key, value in tqdm(lora_sd.items()):
|
||||||
if 'lora_down' in key:
|
if 'lora_down' in key:
|
||||||
@@ -122,39 +159,21 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||||
|
|
||||||
if dynamic_method=="sv_ratio":
|
|
||||||
# Calculate new dim and alpha based off ratio
|
|
||||||
max_sv = S[0]
|
|
||||||
min_sv = max_sv/dynamic_param
|
|
||||||
new_rank = torch.sum(S > min_sv).item()
|
|
||||||
new_rank = max(new_rank, 1)
|
|
||||||
new_alpha = float(scale*new_rank)
|
|
||||||
|
|
||||||
elif dynamic_method=="sv_cumulative":
|
param_dict = rank_resize(S, new_rank, dynamic_method, dynamic_param, scale)
|
||||||
# Calculate new dim and alpha based off cumulative sum
|
|
||||||
new_rank = index_sv_cumulative(S, dynamic_param)
|
new_rank = param_dict['new_rank']
|
||||||
new_rank = max(new_rank, 1)
|
new_alpha = param_dict['new_alpha']
|
||||||
new_alpha = float(scale*new_rank)
|
|
||||||
|
|
||||||
elif dynamic_method=="sv_fro":
|
|
||||||
# Calculate new dim and alpha based off sqrt sum of squares
|
|
||||||
new_rank = index_sv_fro(S, dynamic_param)
|
|
||||||
new_rank = max(new_rank, 1)
|
|
||||||
new_alpha = float(scale*new_rank)
|
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
s_sum = torch.sum(torch.abs(S))
|
max_ratio = param_dict['max_ratio']
|
||||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
sum_retained = param_dict['sum_retained']
|
||||||
|
fro_retained = param_dict['fro_retained']
|
||||||
S_squared = S.pow(2)
|
if not np.isnan(fro_retained):
|
||||||
s_fro = torch.sqrt(torch.sum(S_squared))
|
fro_list.append(float(fro_retained))
|
||||||
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
|
||||||
fro_percent = float(s_red_fro/s_fro)
|
|
||||||
if not np.isnan(fro_percent):
|
|
||||||
fro_list.append(float(fro_percent))
|
|
||||||
|
|
||||||
verbose_str+=f"{block_down_name:75} | "
|
verbose_str+=f"{block_down_name:75} | "
|
||||||
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, fro retained: {fro_percent:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"
|
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||||
|
|
||||||
|
|
||||||
if verbose and dynamic_method:
|
if verbose and dynamic_method:
|
||||||
@@ -168,12 +187,11 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
|
|
||||||
Vh = Vh[:new_rank, :]
|
Vh = Vh[:new_rank, :]
|
||||||
|
|
||||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
# dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
# hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||||
low_val = -hi_val
|
# low_val = -hi_val
|
||||||
|
# U = U.clamp(low_val, hi_val)
|
||||||
U = U.clamp(low_val, hi_val)
|
# Vh = Vh.clamp(low_val, hi_val)
|
||||||
Vh = Vh.clamp(low_val, hi_val)
|
|
||||||
|
|
||||||
if conv2d:
|
if conv2d:
|
||||||
U = U.unsqueeze(2).unsqueeze(3)
|
U = U.unsqueeze(2).unsqueeze(3)
|
||||||
@@ -223,7 +241,7 @@ def resize(args):
|
|||||||
print("loading Model...")
|
print("loading Model...")
|
||||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||||
|
|
||||||
print("resizing rank...")
|
print("Resizing Lora...")
|
||||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
||||||
|
|
||||||
# update metadata
|
# update metadata
|
||||||
|
|||||||
Reference in New Issue
Block a user