mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add dtype
This commit is contained in:
@@ -19,16 +19,16 @@ CLAMP_QUANTILE = 0.99
|
|||||||
MIN_DIFF = 1e-6
|
MIN_DIFF = 1e-6
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, state_dict, dtype):
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
for key in list(state_dict.keys()):
|
for key in list(state_dict.keys()):
|
||||||
if type(state_dict[key]) == torch.Tensor:
|
if type(state_dict[key]) == torch.Tensor:
|
||||||
state_dict[key] = state_dict[key].to(dtype)
|
state_dict[key] = state_dict[key].to(dtype)
|
||||||
|
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||||
save_file(model, file_name)
|
save_file(state_dict, file_name)
|
||||||
else:
|
else:
|
||||||
torch.save(model, file_name)
|
torch.save(state_dict, file_name)
|
||||||
|
|
||||||
|
|
||||||
def svd(args):
|
def svd(args):
|
||||||
@@ -45,15 +45,15 @@ def svd(args):
|
|||||||
|
|
||||||
# Diffusersのキーに変換するため、original sdとcontrol sdからU-Netに重みを読み込む ###############
|
# Diffusersのキーに変換するため、original sdとcontrol sdからU-Netに重みを読み込む ###############
|
||||||
|
|
||||||
# original sdをDiffusersに読み込む
|
# original sdをDiffusersのU-Netに読み込む
|
||||||
print(f"loading original SD model : {args.model_org}")
|
print(f"loading original SD model : {args.model_org}")
|
||||||
org_text_encoder, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
_, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
||||||
|
|
||||||
org_sd = torch.load(args.model_org, map_location='cpu')
|
org_sd = torch.load(args.model_org, map_location='cpu')
|
||||||
if 'state_dict' in org_sd:
|
if 'state_dict' in org_sd:
|
||||||
org_sd = org_sd['state_dict']
|
org_sd = org_sd['state_dict']
|
||||||
|
|
||||||
# control sdからキー変換しつつU-Netに対応する部分のみ取り出す
|
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのuU-Netに読み込む
|
||||||
print(f"loading control SD model : {args.model_tuned}")
|
print(f"loading control SD model : {args.model_tuned}")
|
||||||
|
|
||||||
ctrl_sd = torch.load(args.model_tuned, map_location='cpu')
|
ctrl_sd = torch.load(args.model_tuned, map_location='cpu')
|
||||||
@@ -65,8 +65,8 @@ def svd(args):
|
|||||||
continue
|
continue
|
||||||
ctrl_unet_sd[unet_key] = ctrl_sd[key]
|
ctrl_unet_sd[unet_key] = ctrl_sd[key]
|
||||||
|
|
||||||
unet_config = model_util.create_unet_diffusers_config(False)
|
unet_config = model_util.create_unet_diffusers_config(args.v2)
|
||||||
ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(False, ctrl_unet_sd, unet_config)
|
ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(args.v2, ctrl_unet_sd, unet_config)
|
||||||
|
|
||||||
# load weights to U-Net
|
# load weights to U-Net
|
||||||
ctrl_unet = UNet2DConditionModel(**unet_config)
|
ctrl_unet = UNet2DConditionModel(**unet_config)
|
||||||
@@ -75,8 +75,6 @@ def svd(args):
|
|||||||
|
|
||||||
# LoRAに対応する部分のU-Netの重みを読み込む #################################
|
# LoRAに対応する部分のU-Netの重みを読み込む #################################
|
||||||
|
|
||||||
org_unet_sd_du = org_unet.state_dict()
|
|
||||||
|
|
||||||
diffs = {}
|
diffs = {}
|
||||||
for (org_name, org_module), (ctrl_name, ctrl_module) in zip(org_unet.named_modules(), ctrl_unet.named_modules()):
|
for (org_name, org_module), (ctrl_name, ctrl_module) in zip(org_unet.named_modules(), ctrl_unet.named_modules()):
|
||||||
if org_module.__class__.__name__ != "Linear" and org_module.__class__.__name__ != "Conv2d":
|
if org_module.__class__.__name__ != "Linear" and org_module.__class__.__name__ != "Conv2d":
|
||||||
@@ -172,6 +170,7 @@ def svd(args):
|
|||||||
if 'zero_convs' in key or 'input_hint_block' in key or 'middle_block_out' in key:
|
if 'zero_convs' in key or 'input_hint_block' in key or 'middle_block_out' in key:
|
||||||
ctrl_lora_sd[key] = value
|
ctrl_lora_sd[key] = value
|
||||||
|
|
||||||
|
# verify state dict by loading it
|
||||||
info = lora_network.load_state_dict(ctrl_lora_sd)
|
info = lora_network.load_state_dict(ctrl_lora_sd)
|
||||||
print(f"loading control lora sd: {info}")
|
print(f"loading control lora sd: {info}")
|
||||||
|
|
||||||
@@ -183,7 +182,7 @@ def svd(args):
|
|||||||
# metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
# metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||||
|
|
||||||
# lora_network.save_weights(args.save_to, save_dtype, metadata)
|
# lora_network.save_weights(args.save_to, save_dtype, metadata)
|
||||||
save_file(ctrl_lora_sd, args.save_to)
|
save_to_file(args.save_to, ctrl_lora_sd, save_dtype)
|
||||||
print(f"LoRA weights are saved to: {args.save_to}")
|
print(f"LoRA weights are saved to: {args.save_to}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user