Merge branch 'dev' into resize_lora-add-rank-for-conv

This commit is contained in:
Kohya S
2024-02-24 20:09:38 +09:00
committed by GitHub
72 changed files with 6491 additions and 1529 deletions

View File

@@ -9,43 +9,52 @@ from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import train_util
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
# Model save and load functions
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
metadata = train_util.load_metadata_from_safetensors(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location="cpu")
metadata = {}
metadata = None
for key in list(sd.keys()):
if isinstance(sd[key], torch.Tensor):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd, metadata
def save_to_file(file_name, state_dict, dtype, metadata):
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if isinstance(state_dict[key], torch.Tensor):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(state_dict, file_name, metadata=metadata)
if model_util.is_safetensors(file_name):
save_file(model, file_name, metadata)
else:
torch.save(state_dict, file_name)
torch.save(model, file_name)
# Indexing functions
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S)-1))
index = max(1, min(index, len(S) - 1))
return index
@@ -136,26 +145,27 @@ def merge_linear(lora_down, lora_up, device):
# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {}
if dynamic_method == "sv_ratio":
# Calculate new dim and alpha based off ratio
new_rank = index_sv_ratio(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
elif dynamic_method == "sv_cumulative":
# Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
elif dynamic_method == "sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
else:
new_rank = rank
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
@@ -171,13 +181,13 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
S_squared = S.pow(2)
s_fro = torch.sqrt(torch.sum(S_squared))
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
fro_percent = float(s_red_fro/s_fro)
fro_percent = float(s_red_fro / s_fro)
param_dict["new_rank"] = new_rank
param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["sum_retained"] = (s_rank) / s_sum
param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
return param_dict
@@ -202,7 +212,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
scale = network_alpha/network_dim
if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
lora_down_weight = None
lora_up_weight = None
@@ -271,10 +281,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
del param_dict
if verbose:
print(verbose_str)
logger.info(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
print("resizing complete")
logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
logger.info("resizing complete")
return o_lora_sd, network_dim, new_alpha
@@ -307,10 +317,10 @@ def resize(args):
if save_dtype is None:
save_dtype = merge_dtype
print("loading Model...")
logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print("Resizing Lora...")
logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
# update metadata
@@ -332,8 +342,8 @@ def resize(args):
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -359,7 +369,7 @@ def setup_parser() -> argparse.ArgumentParser:
return parser
if __name__ == '__main__':
parser = setup_parser()