keep metadata when resizing

This commit is contained in:
Kohya S
2023-02-10 22:55:00 +09:00
parent d2da3c4236
commit c7406d6b27

View File

@@ -5,37 +5,40 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm

from library import train_util, model_util
def load_state_dict(file_name, dtype):
    """Load a LoRA state dict plus any embedded metadata.

    Args:
        file_name: path to a ``.safetensors`` file or a torch checkpoint.
        dtype: torch dtype that every tensor entry is cast to.

    Returns:
        Tuple ``(state_dict, metadata)``. ``metadata`` is the safetensors
        header metadata dict, or ``None`` for torch checkpoints (which
        carry no metadata).
    """
    if model_util.is_safetensors(file_name):
        sd = load_file(file_name)
        # load_file() returns only the tensors; the header metadata has to
        # be read separately via safe_open().
        with safe_open(file_name, framework="pt") as f:
            metadata = f.metadata()
    else:
        sd = torch.load(file_name, map_location='cpu')
        metadata = None

    # Cast tensor-valued entries in place to the requested dtype.
    for key in list(sd.keys()):
        if type(sd[key]) == torch.Tensor:
            sd[key] = sd[key].to(dtype)

    return sd, metadata
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None: if dtype is not None:
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor: if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype) state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors': if model_util.is_safetensors(file_name):
save_file(model, file_name) save_file(model, file_name, metadata)
else: else:
torch.save(model, file_name) torch.save(model, file_name)
def resize_lora_model(lora_sd, new_rank, save_dtype, device):
def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
print("Loading Model...")
lora_sd = load_state_dict(model, merge_dtype)
network_alpha = None network_alpha = None
network_dim = None network_dim = None
@@ -55,7 +58,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
scale = network_alpha/network_dim scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale new_alpha = float(scale*new_rank) # calculate new alpha from scale
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}") print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
lora_down_weight = None lora_down_weight = None
lora_up_weight = None lora_up_weight = None
@@ -84,7 +87,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
lora_down_weight = lora_down_weight.squeeze() lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze() lora_up_weight = lora_up_weight.squeeze()
if args.device: if device:
org_device = lora_up_weight.device org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device) lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device) lora_down_weight = lora_down_weight.to(args.device)
@@ -125,7 +128,8 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
weights_loaded = False weights_loaded = False
print("resizing complete") print("resizing complete")
return o_lora_sd return o_lora_sd, network_dim, new_alpha
def resize(args): def resize(args):
@@ -143,10 +147,27 @@ def resize(args):
if save_dtype is None: if save_dtype is None:
save_dtype = merge_dtype save_dtype = merge_dtype
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype) print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print("resizing rank...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
# update metadata
if metadata is None:
metadata = {}
comment = metadata.get("ss_training_comment", "")
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}") print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
if __name__ == '__main__': if __name__ == '__main__':