mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add verbosity option for resize_lora.py
Add a --verbose flag to print additional statistics during the resize_lora function; also correct some parameter references in the resize_lora_model function.
This commit is contained in:
@@ -38,9 +38,10 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||||
network_alpha = None
|
network_alpha = None
|
||||||
network_dim = None
|
network_dim = None
|
||||||
|
verbose_str = "\n"
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
@@ -96,6 +97,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
|||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
s_sum = torch.sum(torch.abs(S))
|
||||||
|
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||||
|
verbose_str+=f"{block_down_name:76} | "
|
||||||
|
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}%, max(S) to max(S_dropped) ratio: {S[0]/S[new_rank]:0.1f}\n"
|
||||||
|
|
||||||
U = U[:, :new_rank]
|
U = U[:, :new_rank]
|
||||||
S = S[:new_rank]
|
S = S[:new_rank]
|
||||||
U = U @ torch.diag(S)
|
U = U @ torch.diag(S)
|
||||||
@@ -113,7 +120,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
|||||||
U = U.unsqueeze(2).unsqueeze(3)
|
U = U.unsqueeze(2).unsqueeze(3)
|
||||||
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
||||||
|
|
||||||
if args.device:
|
if device:
|
||||||
U = U.to(org_device)
|
U = U.to(org_device)
|
||||||
Vh = Vh.to(org_device)
|
Vh = Vh.to(org_device)
|
||||||
|
|
||||||
@@ -127,6 +134,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
|||||||
lora_up_weight = None
|
lora_up_weight = None
|
||||||
weights_loaded = False
|
weights_loaded = False
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(verbose_str)
|
||||||
print("resizing complete")
|
print("resizing complete")
|
||||||
return o_lora_sd, network_dim, new_alpha
|
return o_lora_sd, network_dim, new_alpha
|
||||||
|
|
||||||
@@ -151,7 +160,7 @@ def resize(args):
|
|||||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||||
|
|
||||||
print("resizing rank...")
|
print("resizing rank...")
|
||||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
|
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
|
||||||
|
|
||||||
# update metadata
|
# update metadata
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
@@ -182,6 +191,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--model", type=str, default=None,
|
parser.add_argument("--model", type=str, default=None,
|
||||||
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
|
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
parser.add_argument("--verbose", action="store_true",
|
||||||
|
help="Display verbose resizing information")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
resize(args)
|
resize(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user