diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 3510b553..0bc1afe0 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -3,16 +3,18 @@ # Thanks to cloneofsimo! import argparse +import json import os import torch from safetensors.torch import load_file, save_file from tqdm import tqdm import library.model_util as model_util +import library.sdxl_model_util as sdxl_model_util import lora CLAMP_QUANTILE = 0.99 -MIN_DIFF = 1e-6 +MIN_DIFF = 1e-4 def save_to_file(file_name, model, state_dict, dtype): @@ -37,12 +39,35 @@ def svd(args): return torch.bfloat16 return None + assert args.v2 != args.sdxl or ( + not args.v2 and not args.sdxl + ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" + if args.v_parameterization is None: + args.v_parameterization = args.v2 + save_dtype = str_to_dtype(args.save_precision) - print(f"loading SD model : {args.model_org}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) - print(f"loading SD model : {args.model_tuned}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + # load models + if not args.sdxl: + print(f"loading original SD model : {args.model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + text_encoders_o = [text_encoder_o] + print(f"loading tuned SD model : {args.model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + text_encoders_t = [text_encoder_t] + model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) + else: + print(f"loading original SDXL model : {args.model_org}") + text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_org, "cpu" + ) + text_encoders_o = [text_encoder_o1, text_encoder_o2] + print(f"loading original SDXL model : {args.model_tuned}") + text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_tuned, "cpu" + ) + text_encoders_t = [text_encoder_t1, text_encoder_t2] + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9 # create LoRA network to extract weights: Use dim (rank) as alpha if args.conv_dim is None: @@ -50,8 +75,8 @@ def svd(args): else: kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} - lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs) - lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs) + lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) assert len(lora_network_o.text_encoder_loras) == len( lora_network_t.text_encoder_loras ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " @@ -66,8 +91,9 @@ def svd(args): diff = module_t.weight - module_o.weight # Text Encoder might be same - if torch.max(torch.abs(diff)) > MIN_DIFF: + if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: text_encoder_different = True + print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") diff = diff.float() diffs[lora_name] = diff @@ -146,8 +172,8 @@ def svd(args): lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # load state dict to LoRA and save it - lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) - lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict + lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") @@ -157,7 +183,19 @@ def svd(args): os.makedirs(dir_name, exist_ok=True) # minimum metadata - metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} + net_kwargs = {} + if args.conv_dim is not None: + net_kwargs["conv_dim"] = args.conv_dim + net_kwargs["conv_alpha"] = args.conv_dim + + metadata = { + "ss_v2": str(args.v2), + "ss_base_model_version": model_version, + "ss_network_module": "networks.lora", + "ss_network_dim": str(args.dim), + "ss_network_alpha": str(args.dim), + "ss_network_args": json.dumps(net_kwargs), + } lora_network_save.save_weights(args.save_to, save_dtype, metadata) print(f"LoRA weights are saved to: {args.save_to}") @@ -166,6 +204,15 @@ def svd(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") + parser.add_argument( + "--v_parameterization", + type=bool, + default=None, + help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", + ) + parser.add_argument( + "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" + ) parser.add_argument( "--save_precision", type=str,