mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
use **kwargs and change svd() calling convention to make svd() reusable
* add required attributes to model_org, model_tuned, save_to * set "*_alpha" using str(float(foo))
This commit is contained in:
@@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def svd(args):
|
def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, no_metadata=False):
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == "float":
|
if p == "float":
|
||||||
return torch.float
|
return torch.float
|
||||||
@@ -39,44 +39,44 @@ def svd(args):
|
|||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
return None
|
return None
|
||||||
|
|
||||||
assert args.v2 != args.sdxl or (
|
assert v2 != sdxl or (
|
||||||
not args.v2 and not args.sdxl
|
not v2 and not sdxl
|
||||||
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
||||||
if args.v_parameterization is None:
|
if v_parameterization is None:
|
||||||
args.v_parameterization = args.v2
|
v_parameterization = v2
|
||||||
|
|
||||||
save_dtype = str_to_dtype(args.save_precision)
|
save_dtype = str_to_dtype(save_precision)
|
||||||
|
|
||||||
# load models
|
# load models
|
||||||
if not args.sdxl:
|
if not sdxl:
|
||||||
print(f"loading original SD model : {args.model_org}")
|
print(f"loading original SD model : {model_org}")
|
||||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||||
text_encoders_o = [text_encoder_o]
|
text_encoders_o = [text_encoder_o]
|
||||||
print(f"loading tuned SD model : {args.model_tuned}")
|
print(f"loading tuned SD model : {model_tuned}")
|
||||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||||
text_encoders_t = [text_encoder_t]
|
text_encoders_t = [text_encoder_t]
|
||||||
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
|
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||||
else:
|
else:
|
||||||
print(f"loading original SDXL model : {args.model_org}")
|
print(f"loading original SDXL model : {model_org}")
|
||||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
|
||||||
)
|
)
|
||||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||||
print(f"loading original SDXL model : {args.model_tuned}")
|
print(f"loading original SDXL model : {model_tuned}")
|
||||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
|
||||||
)
|
)
|
||||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||||
|
|
||||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||||
if args.conv_dim is None:
|
if conv_dim is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
else:
|
else:
|
||||||
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
|
||||||
|
|
||||||
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
|
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
|
||||||
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
|
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
|
||||||
assert len(lora_network_o.text_encoder_loras) == len(
|
assert len(lora_network_o.text_encoder_loras) == len(
|
||||||
lora_network_t.text_encoder_loras
|
lora_network_t.text_encoder_loras
|
||||||
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||||
@@ -120,16 +120,16 @@ def svd(args):
|
|||||||
lora_weights = {}
|
lora_weights = {}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora_name, mat in tqdm(list(diffs.items())):
|
for lora_name, mat in tqdm(list(diffs.items())):
|
||||||
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||||
conv2d = len(mat.size()) == 4
|
conv2d = len(mat.size()) == 4
|
||||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
|
|
||||||
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
||||||
out_dim, in_dim = mat.size()[0:2]
|
out_dim, in_dim = mat.size()[0:2]
|
||||||
|
|
||||||
if args.device:
|
if device:
|
||||||
mat = mat.to(args.device)
|
mat = mat.to(device)
|
||||||
|
|
||||||
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||||
@@ -178,34 +178,34 @@ def svd(args):
|
|||||||
info = lora_network_save.load_state_dict(lora_sd)
|
info = lora_network_save.load_state_dict(lora_sd)
|
||||||
print(f"Loading extracted LoRA weights: {info}")
|
print(f"Loading extracted LoRA weights: {info}")
|
||||||
|
|
||||||
dir_name = os.path.dirname(args.save_to)
|
dir_name = os.path.dirname(save_to)
|
||||||
if dir_name and not os.path.exists(dir_name):
|
if dir_name and not os.path.exists(dir_name):
|
||||||
os.makedirs(dir_name, exist_ok=True)
|
os.makedirs(dir_name, exist_ok=True)
|
||||||
|
|
||||||
# minimum metadata
|
# minimum metadata
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
if args.conv_dim is not None:
|
if conv_dim is not None:
|
||||||
net_kwargs["conv_dim"] = args.conv_dim
|
net_kwargs["conv_dim"] = str(conv_dim)
|
||||||
net_kwargs["conv_alpha"] = args.conv_dim
|
net_kwargs["conv_alpha"] = str(float(conv_dim))
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"ss_v2": str(args.v2),
|
"ss_v2": str(v2),
|
||||||
"ss_base_model_version": model_version,
|
"ss_base_model_version": model_version,
|
||||||
"ss_network_module": "networks.lora",
|
"ss_network_module": "networks.lora",
|
||||||
"ss_network_dim": str(args.dim),
|
"ss_network_dim": str(dim),
|
||||||
"ss_network_alpha": str(args.dim),
|
"ss_network_alpha": str(float(dim)),
|
||||||
"ss_network_args": json.dumps(net_kwargs),
|
"ss_network_args": json.dumps(net_kwargs),
|
||||||
}
|
}
|
||||||
|
|
||||||
if not args.no_metadata:
|
if not no_metadata:
|
||||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||||
sai_metadata = sai_model_spec.build_metadata(
|
sai_metadata = sai_model_spec.build_metadata(
|
||||||
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
|
None, v2, v_parameterization, sdxl, True, False, time.time(), title=title
|
||||||
)
|
)
|
||||||
metadata.update(sai_metadata)
|
metadata.update(sai_metadata)
|
||||||
|
|
||||||
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
|
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||||
print(f"LoRA weights are saved to: {args.save_to}")
|
print(f"LoRA weights are saved to: {save_to}")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
@@ -213,7 +213,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--v_parameterization",
|
"--v_parameterization",
|
||||||
type=bool,
|
action="store_true",
|
||||||
default=None,
|
default=None,
|
||||||
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
||||||
)
|
)
|
||||||
@@ -231,16 +231,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--model_org",
|
"--model_org",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
|
required=True,
|
||||||
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_tuned",
|
"--model_tuned",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
|
required=True,
|
||||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
"--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||||
)
|
)
|
||||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -264,4 +266,4 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
svd(args)
|
svd(**vars(args))
|
||||||
|
|||||||
Reference in New Issue
Block a user