mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into resize_lora-add-rank-for-conv
This commit is contained in:
@@ -9,43 +9,52 @@ from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import train_util
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
# Model save and load functions
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
metadata = train_util.load_metadata_from_safetensors(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = {}
|
||||
metadata = None
|
||||
|
||||
for key in list(sd.keys()):
|
||||
if isinstance(sd[key], torch.Tensor):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
|
||||
return sd, metadata
|
||||
|
||||
def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if isinstance(state_dict[key], torch.Tensor):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(state_dict, file_name, metadata=metadata)
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file_name)
|
||||
torch.save(model, file_name)
|
||||
|
||||
# Indexing functions
|
||||
|
||||
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -136,26 +145,27 @@ def merge_linear(lora_down, lora_up, device):
|
||||
|
||||
# Calculate new rank
|
||||
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
if dynamic_method == "sv_ratio":
|
||||
# Calculate new dim and alpha based off ratio
|
||||
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
elif dynamic_method == "sv_cumulative":
|
||||
# Calculate new dim and alpha based off cumulative sum
|
||||
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
elif dynamic_method == "sv_fro":
|
||||
# Calculate new dim and alpha based off sqrt sum of squares
|
||||
new_rank = index_sv_fro(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
new_alpha = float(scale * new_rank)
|
||||
else:
|
||||
new_rank = rank
|
||||
new_alpha = float(scale*new_rank)
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||
new_rank = 1
|
||||
@@ -171,13 +181,13 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
S_squared = S.pow(2)
|
||||
s_fro = torch.sqrt(torch.sum(S_squared))
|
||||
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
||||
fro_percent = float(s_red_fro/s_fro)
|
||||
fro_percent = float(s_red_fro / s_fro)
|
||||
|
||||
param_dict["new_rank"] = new_rank
|
||||
param_dict["new_alpha"] = new_alpha
|
||||
param_dict["sum_retained"] = (s_rank)/s_sum
|
||||
param_dict["sum_retained"] = (s_rank) / s_sum
|
||||
param_dict["fro_retained"] = fro_percent
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
|
||||
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
|
||||
|
||||
return param_dict
|
||||
|
||||
@@ -202,7 +212,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
scale = network_alpha/network_dim
|
||||
|
||||
if dynamic_method:
|
||||
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
||||
logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
||||
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
@@ -271,10 +281,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
logger.info(verbose_str)
|
||||
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
print("resizing complete")
|
||||
logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
logger.info("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
|
||||
|
||||
@@ -307,10 +317,10 @@ def resize(args):
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
print("loading Model...")
|
||||
logger.info("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
|
||||
print("Resizing Lora...")
|
||||
logger.info("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
||||
|
||||
# update metadata
|
||||
@@ -332,8 +342,8 @@ def resize(args):
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
@@ -359,7 +369,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user