add dtype

This commit is contained in:
Kohya S
2023-02-16 21:59:35 +09:00
parent 39aa390d2b
commit 8590d5dbca

View File

@@ -19,16 +19,16 @@ CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6
def save_to_file(file_name, model, state_dict, dtype):
def save_to_file(file_name, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
save_file(state_dict, file_name)
else:
torch.save(model, file_name)
torch.save(state_dict, file_name)
def svd(args):
@@ -45,15 +45,15 @@ def svd(args):
# Diffusersのキーに変換するため、original sdとcontrol sdからU-Netに重みを読み込む ###############
# original sdをDiffusersに読み込む
# original sdをDiffusersのU-Netに読み込む
print(f"loading original SD model : {args.model_org}")
org_text_encoder, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
_, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
org_sd = torch.load(args.model_org, map_location='cpu')
if 'state_dict' in org_sd:
org_sd = org_sd['state_dict']
# control sdからキー変換しつつU-Netに対応する部分のみ取り出
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのuU-Netに読み込む
print(f"loading control SD model : {args.model_tuned}")
ctrl_sd = torch.load(args.model_tuned, map_location='cpu')
@@ -65,8 +65,8 @@ def svd(args):
continue
ctrl_unet_sd[unet_key] = ctrl_sd[key]
unet_config = model_util.create_unet_diffusers_config(False)
ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(False, ctrl_unet_sd, unet_config)
unet_config = model_util.create_unet_diffusers_config(args.v2)
ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(args.v2, ctrl_unet_sd, unet_config)
# load weights to U-Net
ctrl_unet = UNet2DConditionModel(**unet_config)
@@ -75,8 +75,6 @@ def svd(args):
# LoRAに対応する部分のU-Netの重みを読み込む #################################
org_unet_sd_du = org_unet.state_dict()
diffs = {}
for (org_name, org_module), (ctrl_name, ctrl_module) in zip(org_unet.named_modules(), ctrl_unet.named_modules()):
if org_module.__class__.__name__ != "Linear" and org_module.__class__.__name__ != "Conv2d":
@@ -172,6 +170,7 @@ def svd(args):
if 'zero_convs' in key or 'input_hint_block' in key or 'middle_block_out' in key:
ctrl_lora_sd[key] = value
# verify state dict by loading it
info = lora_network.load_state_dict(ctrl_lora_sd)
print(f"loading control lora sd: {info}")
@@ -183,7 +182,7 @@ def svd(args):
# metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
# lora_network.save_weights(args.save_to, save_dtype, metadata)
save_file(ctrl_lora_sd, args.save_to)
save_to_file(args.save_to, ctrl_lora_sd, save_dtype)
print(f"LoRA weights are saved to: {args.save_to}")