mirror of https://github.com/kohya-ss/sd-scripts.git

Commit: support sai model spec
@@ -1,9 +1,11 @@
 import math
 import argparse
 import os
+import time
 import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
+from library import sai_model_spec, train_util
 import library.model_util as model_util
 import lora
 
@@ -14,22 +16,26 @@ CLAMP_QUANTILE = 0.99
 def load_state_dict(file_name, dtype):
     if os.path.splitext(file_name)[1] == ".safetensors":
         sd = load_file(file_name)
+        metadata = train_util.load_metadata_from_safetensors(file_name)
     else:
         sd = torch.load(file_name, map_location="cpu")
+        metadata = {}
+
     for key in list(sd.keys()):
         if type(sd[key]) == torch.Tensor:
             sd[key] = sd[key].to(dtype)
-    return sd
+
+    return sd, metadata
 
 
-def save_to_file(file_name, state_dict, dtype):
+def save_to_file(file_name, state_dict, dtype, metadata):
     if dtype is not None:
         for key in list(state_dict.keys()):
             if type(state_dict[key]) == torch.Tensor:
                 state_dict[key] = state_dict[key].to(dtype)
 
     if os.path.splitext(file_name)[1] == ".safetensors":
-        save_file(state_dict, file_name)
+        save_file(state_dict, file_name, metadata=metadata)
     else:
         torch.save(state_dict, file_name)
 
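For context on the metadata round-trip introduced above: safetensors stores metadata as a flat string-to-string mapping in the file header, which is why load_state_dict can return it alongside the tensors and save_to_file can pass it straight to save_file. A minimal sketch of the underlying safetensors API (the file name here is illustrative; train_util.load_metadata_from_safetensors is the repository's own wrapper around this):

    import torch
    from safetensors import safe_open
    from safetensors.torch import save_file

    # all metadata values must be strings; save_file rejects anything else
    save_file({"w": torch.zeros(2, 2)}, "example.safetensors", metadata={"ss_v2": "False"})

    # read the header metadata back without materializing any tensors
    with safe_open("example.safetensors", framework="pt") as f:
        metadata = f.metadata()  # {'ss_v2': 'False'}, or None if the file has none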
@@ -37,9 +43,17 @@ def save_to_file(file_name, state_dict, dtype):
 def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
     print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
     merged_sd = {}
+    v2 = None
+    base_model = None
     for model, ratio in zip(models, ratios):
         print(f"loading: {model}")
-        lora_sd = load_state_dict(model, merge_dtype)
+        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
+
+        if lora_metadata is not None:
+            if v2 is None:
+                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # returns a string
+            if base_model is None:
+                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
 
         # merge
         print(f"merging...")
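One subtlety in this hunk: because safetensors metadata values are strings, the value read for train_util.SS_METADATA_KEY_V2 is the string "True" or "False", not a bool, which is why the function later returns v2 == "True" rather than v2 itself. A hypothetical illustration (the literal key names are assumptions about the train_util constants):

    # metadata loaded from a .safetensors header is str -> str
    lora_metadata = {"ss_v2": "False", "ss_base_model_version": "sd_v1"}

    v2 = lora_metadata.get("ss_v2", None)  # "False": a string, not the bool False
    is_v2 = v2 == "True"                   # explicit conversion back to bool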
@@ -140,7 +154,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
         merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
         merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
 
-    return merged_lora_sd
+    # build minimum metadata
+    dims = f"{new_rank}"
+    alphas = f"{new_rank}"
+    if new_conv_rank is not None:
+        network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank}
+    else:
+        network_args = None
+    metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args)
+
+    return merged_lora_sd, metadata, v2 == "True", base_model
 
 
 def merge(args):
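build_minimum_network_metadata is defined in library/train_util.py and its output is not shown in this diff; judging from the arguments passed here, it plausibly assembles the standard ss_* training-metadata keys, roughly like this (key names and values are illustrative assumptions, not taken from this commit):

    import json

    # hypothetical equivalent for new_rank=16, new_conv_rank=8, v2="False", base_model="sd_v1"
    metadata = {
        "ss_network_module": "networks.lora",
        "ss_network_dim": "16",    # dims = f"{new_rank}"
        "ss_network_alpha": "16",  # alphas = f"{new_rank}"
        "ss_v2": "False",
        "ss_base_model_version": "sd_v1",
        "ss_network_args": json.dumps({"conv_dim": 8, "conv_alpha": 8}),
    }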
@@ -161,10 +184,32 @@ def merge(args):
         save_dtype = merge_dtype
 
     new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
-    state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
+    state_dict, metadata, v2, base_model = merge_lora_models(
+        args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
+    )
+
+    print(f"calculating hashes and creating metadata...")
+
+    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+    metadata["sshs_model_hash"] = model_hash
+    metadata["sshs_legacy_hash"] = legacy_hash
+
+    if not args.no_metadata:
+        is_sdxl = base_model is not None and base_model.lower().startswith("sdxl")
+        merged_from = sai_model_spec.build_merged_from(args.models)
+        title = os.path.splitext(os.path.basename(args.save_to))[0]
+        sai_metadata = sai_model_spec.build_metadata(
+            state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from
+        )
+        if v2:
+            # TODO read sai modelspec
+            print(
+                "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
+            )
+        metadata.update(sai_metadata)
 
     print(f"saving model to: {args.save_to}")
-    save_to_file(args.save_to, state_dict, save_dtype)
+    save_to_file(args.save_to, state_dict, save_dtype, metadata)
 
 
 def setup_parser() -> argparse.ArgumentParser:
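The positional arguments to sai_model_spec.build_metadata are easy to misread, in particular v2 being passed twice. The same call annotated, with the meaning of each argument inferred from how it is used in this hunk (the comments are descriptive assumptions, not the library's parameter names):

    sai_metadata = sai_model_spec.build_metadata(
        state_dict,   # merged weights
        v2,           # is this an SD2.x model?
        v2,           # v-parameterization: unknown, so v2 is reused (hence the warning)
        is_sdxl,      # derived from ss_base_model_version above
        True,         # this is a LoRA network
        False,        # not a textual inversion embedding
        time.time(),  # creation timestamp
        title=title,
        merged_from=merged_from,
    )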
@@ -198,6 +243,12 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
     )
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument(
+        "--no_metadata",
+        action="store_true",
+        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
+    )
 
     return parser
 
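For reference, a typical invocation after this change; the script and file names are illustrative (this diff matches networks/svd_merge_lora.py in the repository):

    python networks/svd_merge_lora.py \
        --models loraA.safetensors loraB.safetensors --ratios 1.0 0.5 \
        --new_rank 16 --new_conv_rank 8 \
        --save_to merged.safetensors
    # pass --no_metadata to skip the sai modelspec block; the minimum ss_* metadata is still written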