support sdxl

This commit is contained in:
Kohya S
2023-07-24 22:20:21 +09:00
parent e83ee217d3
commit 2b969e9c42

View File

@@ -3,16 +3,18 @@
# Thanks to cloneofsimo! # Thanks to cloneofsimo!
import argparse import argparse
import json
import os import os
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from tqdm import tqdm from tqdm import tqdm
import library.model_util as model_util import library.model_util as model_util
import library.sdxl_model_util as sdxl_model_util
import lora import lora
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6 MIN_DIFF = 1e-4
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, model, state_dict, dtype):
@@ -37,12 +39,35 @@ def svd(args):
return torch.bfloat16 return torch.bfloat16
return None return None
assert args.v2 != args.sdxl or (
not args.v2 and not args.sdxl
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if args.v_parameterization is None:
args.v_parameterization = args.v2
save_dtype = str_to_dtype(args.save_precision) save_dtype = str_to_dtype(args.save_precision)
print(f"loading SD model : {args.model_org}") # load models
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) if not args.sdxl:
print(f"loading SD model : {args.model_tuned}") print(f"loading original SD model : {args.model_org}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
text_encoders_o = [text_encoder_o]
print(f"loading tuned SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
text_encoders_t = [text_encoder_t]
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
else:
print(f"loading original SDXL model : {args.model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_org, "cpu"
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
print(f"loading original SDXL model : {args.model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_tuned, "cpu"
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9
# create LoRA network to extract weights: Use dim (rank) as alpha # create LoRA network to extract weights: Use dim (rank) as alpha
if args.conv_dim is None: if args.conv_dim is None:
@@ -50,8 +75,8 @@ def svd(args):
else: else:
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs) lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs) lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len( assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras lora_network_t.text_encoder_loras
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース " ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@@ -66,8 +91,9 @@ def svd(args):
diff = module_t.weight - module_o.weight diff = module_t.weight - module_o.weight
# Text Encoder might be same # Text Encoder might be same
if torch.max(torch.abs(diff)) > MIN_DIFF: if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
text_encoder_different = True text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
diff = diff.float() diff = diff.float()
diffs[lora_name] = diff diffs[lora_name] = diff
@@ -146,8 +172,8 @@ def svd(args):
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
# load state dict to LoRA and save it # load state dict to LoRA and save it
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
info = lora_network_save.load_state_dict(lora_sd) info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}") print(f"Loading extracted LoRA weights: {info}")
@@ -157,7 +183,19 @@ def svd(args):
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
# minimum metadata # minimum metadata
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} net_kwargs = {}
if args.conv_dim is not None:
net_kwargs["conv_dim"] = args.conv_dim
net_kwargs["conv_alpha"] = args.conv_dim
metadata = {
"ss_v2": str(args.v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": str(args.dim),
"ss_network_alpha": str(args.dim),
"ss_network_args": json.dumps(net_kwargs),
}
lora_network_save.save_weights(args.save_to, save_dtype, metadata) lora_network_save.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {args.save_to}")
@@ -166,6 +204,15 @@ def svd(args):
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
parser.add_argument(
"--v_parameterization",
type=bool,
default=None,
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ",
)
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
)
parser.add_argument( parser.add_argument(
"--save_precision", "--save_precision",
type=str, type=str,