use **kwargs and change svd() calling convention to make svd() reusable

* add required attributes to model_org, model_tuned, save_to
 * set "*_alpha" using str(float(foo))
This commit is contained in:
Won-Kyu Park
2023-11-08 19:09:23 +09:00
parent 6231aa91e2
commit e20e9f61ac

View File

@@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype):
torch.save(model, file_name) torch.save(model, file_name)
def svd(args): def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, no_metadata=False):
def str_to_dtype(p): def str_to_dtype(p):
if p == "float": if p == "float":
return torch.float return torch.float
@@ -39,44 +39,44 @@ def svd(args):
return torch.bfloat16 return torch.bfloat16
return None return None
assert args.v2 != args.sdxl or ( assert v2 != sdxl or (
not args.v2 and not args.sdxl not v2 and not sdxl
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if args.v_parameterization is None: if v_parameterization is None:
args.v_parameterization = args.v2 v_parameterization = v2
save_dtype = str_to_dtype(args.save_precision) save_dtype = str_to_dtype(save_precision)
# load models # load models
if not args.sdxl: if not sdxl:
print(f"loading original SD model : {args.model_org}") print(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o] text_encoders_o = [text_encoder_o]
print(f"loading tuned SD model : {args.model_tuned}") print(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t] text_encoders_t = [text_encoder_t]
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else: else:
print(f"loading original SDXL model : {args.model_org}") print(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
) )
text_encoders_o = [text_encoder_o1, text_encoder_o2] text_encoders_o = [text_encoder_o1, text_encoder_o2]
print(f"loading original SDXL model : {args.model_tuned}") print(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
) )
text_encoders_t = [text_encoder_t1, text_encoder_t2] text_encoders_t = [text_encoder_t1, text_encoder_t2]
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha # create LoRA network to extract weights: Use dim (rank) as alpha
if args.conv_dim is None: if conv_dim is None:
kwargs = {} kwargs = {}
else: else:
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len( assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras lora_network_t.text_encoder_loras
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース " ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@@ -120,16 +120,16 @@ def svd(args):
lora_weights = {} lora_weights = {}
with torch.no_grad(): with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())): for lora_name, mat in tqdm(list(diffs.items())):
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4 conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4] kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1) conv2d_3x3 = conv2d and kernel_size != (1, 1)
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
out_dim, in_dim = mat.size()[0:2] out_dim, in_dim = mat.size()[0:2]
if args.device: if device:
mat = mat.to(args.device) mat = mat.to(device)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
@@ -178,34 +178,34 @@ def svd(args):
info = lora_network_save.load_state_dict(lora_sd) info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}") print(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(args.save_to) dir_name = os.path.dirname(save_to)
if dir_name and not os.path.exists(dir_name): if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
# minimum metadata # minimum metadata
net_kwargs = {} net_kwargs = {}
if args.conv_dim is not None: if conv_dim is not None:
net_kwargs["conv_dim"] = args.conv_dim net_kwargs["conv_dim"] = str(conv_dim)
net_kwargs["conv_alpha"] = args.conv_dim net_kwargs["conv_alpha"] = str(float(conv_dim))
metadata = { metadata = {
"ss_v2": str(args.v2), "ss_v2": str(v2),
"ss_base_model_version": model_version, "ss_base_model_version": model_version,
"ss_network_module": "networks.lora", "ss_network_module": "networks.lora",
"ss_network_dim": str(args.dim), "ss_network_dim": str(dim),
"ss_network_alpha": str(args.dim), "ss_network_alpha": str(float(dim)),
"ss_network_args": json.dumps(net_kwargs), "ss_network_args": json.dumps(net_kwargs),
} }
if not args.no_metadata: if not no_metadata:
title = os.path.splitext(os.path.basename(args.save_to))[0] title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata( sai_metadata = sai_model_spec.build_metadata(
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title None, v2, v_parameterization, sdxl, True, False, time.time(), title=title
) )
metadata.update(sai_metadata) metadata.update(sai_metadata)
lora_network_save.save_weights(args.save_to, save_dtype, metadata) lora_network_save.save_weights(save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {save_to}")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
@@ -213,7 +213,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
parser.add_argument( parser.add_argument(
"--v_parameterization", "--v_parameterization",
type=bool, action="store_true",
default=None, default=None,
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ", help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ",
) )
@@ -231,16 +231,18 @@ def setup_parser() -> argparse.ArgumentParser:
"--model_org", "--model_org",
type=str, type=str,
default=None, default=None,
required=True,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
) )
parser.add_argument( parser.add_argument(
"--model_tuned", "--model_tuned",
type=str, type=str,
default=None, default=None,
required=True,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors", help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
) )
parser.add_argument( parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" "--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
) )
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4") parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument( parser.add_argument(
@@ -264,4 +266,4 @@ if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
svd(args) svd(**vars(args))