refactor and bug fix for too large sv_ratio

- refactor code so the same rank-resize function can be re-used for dynamic LoRA extraction
- remove the clamping of U and Vh
- fix an out-of-bounds index when sv_ratio is set too high (a minimal repro sketch follows below)
This commit is contained in:
mgz-dev
2023-03-03 23:32:46 -06:00
parent 52ca6c515c
commit 80be6fa130
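
For context, a minimal sketch of the failure mode being fixed (illustrative values, not from the repo): with a large enough sv_ratio, min_sv drops below every singular value, new_rank reaches len(S), and the old verbose report's S[0]/S[new_rank] indexes past the end of the tensor.

# Hypothetical repro of the pre-fix failure; values are illustrative, not from the commit.
import torch

S = torch.tensor([4.0, 3.0, 2.0, 1.0])    # singular values of some LoRA block
sv_ratio = 100.0                           # a "too large" dynamic_param

min_sv = S[0] / sv_ratio                   # 0.04 -- every singular value clears it
new_rank = torch.sum(S > min_sv).item()    # 4 == len(S)
new_rank = max(new_rank, 1)                # old code only clamped the lower bound

try:
    max_ratio = S[0] / S[new_rank]         # old verbose report: valid indices are 0..3
except IndexError as e:
    print(f"new_rank={new_rank} is out of bounds for len(S)={len(S)}: {e}")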

@@ -59,14 +59,55 @@ def index_sv_fro(S, target):
     return index
 
+
+def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
+    param_dict = {}
+
+    if dynamic_method=="sv_ratio":
+        # Calculate new dim and alpha based off ratio
+        max_sv = S[0]
+        min_sv = max_sv/dynamic_param
+        new_rank = max(torch.sum(S > min_sv).item(),1)
+        new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_cumulative":
+        # Calculate new dim and alpha based off cumulative sum
+        new_rank = index_sv_cumulative(S, dynamic_param)
+        new_rank = max(new_rank, 1)
+        new_alpha = float(scale*new_rank)
+
+    elif dynamic_method=="sv_fro":
+        # Calculate new dim and alpha based off sqrt sum of squares
+        new_rank = index_sv_fro(S, dynamic_param)
+        new_rank = min(max(new_rank, 1), len(S)-1)
+        new_alpha = float(scale*new_rank)
+    else:
+        new_rank = rank
+        new_alpha = float(scale*new_rank)
+
+    # Calculate resize info
+    s_sum = torch.sum(torch.abs(S))
+    s_rank = torch.sum(torch.abs(S[:new_rank]))
+
+    S_squared = S.pow(2)
+    s_fro = torch.sqrt(torch.sum(S_squared))
+    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
+    fro_percent = float(s_red_fro/s_fro)
+
+    param_dict["new_rank"] = new_rank
+    param_dict["new_alpha"] = new_alpha
+    param_dict["sum_retained"] = (s_rank)/s_sum
+    param_dict["fro_retained"] = fro_percent
+    param_dict["max_ratio"] = S[0]/S[new_rank]
+
+    return param_dict
+
 
 def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
     network_alpha = None
     network_dim = None
     verbose_str = "\n"
     fro_list = []
-    CLAMP_QUANTILE = 0.99
 
     # Extract loaded lora dim and alpha
     for key, value in lora_sd.items():
         if network_alpha is None and 'alpha' in key:
@@ -82,9 +123,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
     if dynamic_method:
         print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}")
-    else:
-        new_alpha = float(scale*new_rank) # calculate new alpha from scale
-        print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}")
 
     lora_down_weight = None
     lora_up_weight = None
@@ -93,7 +131,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
     block_down_name = None
     block_up_name = None
-    print("resizing lora...")
 
     with torch.no_grad():
         for key, value in tqdm(lora_sd.items()):
             if 'lora_down' in key:
@@ -122,39 +159,21 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
                 U, S, Vh = torch.linalg.svd(full_weight_matrix)
 
-                if dynamic_method=="sv_ratio":
-                    # Calculate new dim and alpha based off ratio
-                    max_sv = S[0]
-                    min_sv = max_sv/dynamic_param
-                    new_rank = torch.sum(S > min_sv).item()
-                    new_rank = max(new_rank, 1)
-                    new_alpha = float(scale*new_rank)
-                elif dynamic_method=="sv_cumulative":
-                    # Calculate new dim and alpha based off cumulative sum
-                    new_rank = index_sv_cumulative(S, dynamic_param)
-                    new_rank = max(new_rank, 1)
-                    new_alpha = float(scale*new_rank)
-                elif dynamic_method=="sv_fro":
-                    # Calculate new dim and alpha based off sqrt sum of squares
-                    new_rank = index_sv_fro(S, dynamic_param)
-                    new_rank = max(new_rank, 1)
-                    new_alpha = float(scale*new_rank)
+                param_dict = rank_resize(S, new_rank, dynamic_method, dynamic_param, scale)
+                new_rank = param_dict['new_rank']
+                new_alpha = param_dict['new_alpha']
 
                 if verbose:
-                    s_sum = torch.sum(torch.abs(S))
-                    s_rank = torch.sum(torch.abs(S[:new_rank]))
-
-                    S_squared = S.pow(2)
-                    s_fro = torch.sqrt(torch.sum(S_squared))
-                    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
-                    fro_percent = float(s_red_fro/s_fro)
-
-                    if not np.isnan(fro_percent):
-                        fro_list.append(float(fro_percent))
+                    max_ratio = param_dict['max_ratio']
+                    sum_retained = param_dict['sum_retained']
+                    fro_retained = param_dict['fro_retained']
+                    if not np.isnan(fro_retained):
+                        fro_list.append(float(fro_retained))
 
                     verbose_str+=f"{block_down_name:75} | "
-                    verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, fro retained: {fro_percent:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"
+                    verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
 
                 if verbose and dynamic_method:
@@ -168,12 +187,11 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
                 Vh = Vh[:new_rank, :]
 
-                dist = torch.cat([U.flatten(), Vh.flatten()])
-                hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-                low_val = -hi_val
-
-                U = U.clamp(low_val, hi_val)
-                Vh = Vh.clamp(low_val, hi_val)
+                # dist = torch.cat([U.flatten(), Vh.flatten()])
+                # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+                # low_val = -hi_val
+
+                # U = U.clamp(low_val, hi_val)
+                # Vh = Vh.clamp(low_val, hi_val)
 
                 if conv2d:
                     U = U.unsqueeze(2).unsqueeze(3)
@@ -223,7 +241,7 @@ def resize(args):
print("loading Model...") print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype) lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print("resizing rank...") print("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
# update metadata # update metadata
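
And a minimal usage sketch of the refactored helper, assuming the rank_resize defined in the diff above is in scope; the matrix shapes and ratio here are illustrative, not taken from the repo.

# Assumes rank_resize (and torch) from the script above are importable/in scope.
import torch

# Stand-in for a merged LoRA weight, lora_up @ lora_down, with true rank <= 64.
full_weight_matrix = torch.randn(320, 64) @ torch.randn(64, 320)

U, S, Vh = torch.linalg.svd(full_weight_matrix)

# Keep singular values within a factor of 8 of the largest one (sv_ratio method).
param_dict = rank_resize(S, rank=64, dynamic_method="sv_ratio", dynamic_param=8.0, scale=1.0)

print(f"new dim: {param_dict['new_rank']}, new alpha: {param_dict['new_alpha']}, "
      f"fro retained: {param_dict['fro_retained']:.1%}")

The same call replaces the three per-method branches that previously lived inside resize_lora_model, which is what allows the dynamic extraction path to share the logic.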