mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
keep metadata when resizing
This commit is contained in:
@@ -5,37 +5,40 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from library import train_util, model_util
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if model_util.is_safetensors(file_name):
|
||||||
sd = load_file(file_name)
|
sd = load_file(file_name)
|
||||||
|
with safe_open(file_name, framework="pt") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
else:
|
else:
|
||||||
sd = torch.load(file_name, map_location='cpu')
|
sd = torch.load(file_name, map_location='cpu')
|
||||||
|
metadata = None
|
||||||
|
|
||||||
for key in list(sd.keys()):
|
for key in list(sd.keys()):
|
||||||
if type(sd[key]) == torch.Tensor:
|
if type(sd[key]) == torch.Tensor:
|
||||||
sd[key] = sd[key].to(dtype)
|
sd[key] = sd[key].to(dtype)
|
||||||
return sd
|
|
||||||
|
return sd, metadata
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
for key in list(state_dict.keys()):
|
for key in list(state_dict.keys()):
|
||||||
if type(state_dict[key]) == torch.Tensor:
|
if type(state_dict[key]) == torch.Tensor:
|
||||||
state_dict[key] = state_dict[key].to(dtype)
|
state_dict[key] = state_dict[key].to(dtype)
|
||||||
|
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if model_util.is_safetensors(file_name):
|
||||||
save_file(model, file_name)
|
save_file(model, file_name, metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
||||||
def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
|
||||||
print("Loading Model...")
|
|
||||||
lora_sd = load_state_dict(model, merge_dtype)
|
|
||||||
|
|
||||||
network_alpha = None
|
network_alpha = None
|
||||||
network_dim = None
|
network_dim = None
|
||||||
|
|
||||||
@@ -55,7 +58,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
|||||||
scale = network_alpha/network_dim
|
scale = network_alpha/network_dim
|
||||||
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
||||||
|
|
||||||
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}")
|
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
|
||||||
|
|
||||||
lora_down_weight = None
|
lora_down_weight = None
|
||||||
lora_up_weight = None
|
lora_up_weight = None
|
||||||
@@ -84,7 +87,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
|||||||
lora_down_weight = lora_down_weight.squeeze()
|
lora_down_weight = lora_down_weight.squeeze()
|
||||||
lora_up_weight = lora_up_weight.squeeze()
|
lora_up_weight = lora_up_weight.squeeze()
|
||||||
|
|
||||||
if args.device:
|
if device:
|
||||||
org_device = lora_up_weight.device
|
org_device = lora_up_weight.device
|
||||||
lora_up_weight = lora_up_weight.to(args.device)
|
lora_up_weight = lora_up_weight.to(args.device)
|
||||||
lora_down_weight = lora_down_weight.to(args.device)
|
lora_down_weight = lora_down_weight.to(args.device)
|
||||||
@@ -125,7 +128,8 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
|||||||
weights_loaded = False
|
weights_loaded = False
|
||||||
|
|
||||||
print("resizing complete")
|
print("resizing complete")
|
||||||
return o_lora_sd
|
return o_lora_sd, network_dim, new_alpha
|
||||||
|
|
||||||
|
|
||||||
def resize(args):
|
def resize(args):
|
||||||
|
|
||||||
@@ -143,10 +147,27 @@ def resize(args):
|
|||||||
if save_dtype is None:
|
if save_dtype is None:
|
||||||
save_dtype = merge_dtype
|
save_dtype = merge_dtype
|
||||||
|
|
||||||
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype)
|
print("loading Model...")
|
||||||
|
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||||
|
|
||||||
|
print("resizing rank...")
|
||||||
|
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
|
||||||
|
|
||||||
|
# update metadata
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
comment = metadata.get("ss_training_comment", "")
|
||||||
|
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||||
|
metadata["ss_network_dim"] = str(args.new_rank)
|
||||||
|
metadata["ss_network_alpha"] = str(new_alpha)
|
||||||
|
|
||||||
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
|
metadata["sshs_model_hash"] = model_hash
|
||||||
|
metadata["sshs_legacy_hash"] = legacy_hash
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
print(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user