mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add dtype
This commit is contained in:
@@ -19,16 +19,16 @@ CLAMP_QUANTILE = 0.99
|
||||
MIN_DIFF = 1e-6
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
def save_to_file(file_name, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
save_file(state_dict, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
torch.save(state_dict, file_name)
|
||||
|
||||
|
||||
def svd(args):
|
||||
@@ -45,15 +45,15 @@ def svd(args):
|
||||
|
||||
# Diffusersのキーに変換するため、original sdとcontrol sdからU-Netに重みを読み込む ###############
|
||||
|
||||
# original sdをDiffusersに読み込む
|
||||
# original sdをDiffusersのU-Netに読み込む
|
||||
print(f"loading original SD model : {args.model_org}")
|
||||
org_text_encoder, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
||||
_, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
||||
|
||||
org_sd = torch.load(args.model_org, map_location='cpu')
|
||||
if 'state_dict' in org_sd:
|
||||
org_sd = org_sd['state_dict']
|
||||
|
||||
# control sdからキー変換しつつU-Netに対応する部分のみ取り出す
|
||||
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのuU-Netに読み込む
|
||||
print(f"loading control SD model : {args.model_tuned}")
|
||||
|
||||
ctrl_sd = torch.load(args.model_tuned, map_location='cpu')
|
||||
@@ -65,8 +65,8 @@ def svd(args):
|
||||
continue
|
||||
ctrl_unet_sd[unet_key] = ctrl_sd[key]
|
||||
|
||||
unet_config = model_util.create_unet_diffusers_config(False)
|
||||
ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(False, ctrl_unet_sd, unet_config)
|
||||
unet_config = model_util.create_unet_diffusers_config(args.v2)
|
||||
ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(args.v2, ctrl_unet_sd, unet_config)
|
||||
|
||||
# load weights to U-Net
|
||||
ctrl_unet = UNet2DConditionModel(**unet_config)
|
||||
@@ -75,8 +75,6 @@ def svd(args):
|
||||
|
||||
# LoRAに対応する部分のU-Netの重みを読み込む #################################
|
||||
|
||||
org_unet_sd_du = org_unet.state_dict()
|
||||
|
||||
diffs = {}
|
||||
for (org_name, org_module), (ctrl_name, ctrl_module) in zip(org_unet.named_modules(), ctrl_unet.named_modules()):
|
||||
if org_module.__class__.__name__ != "Linear" and org_module.__class__.__name__ != "Conv2d":
|
||||
@@ -172,6 +170,7 @@ def svd(args):
|
||||
if 'zero_convs' in key or 'input_hint_block' in key or 'middle_block_out' in key:
|
||||
ctrl_lora_sd[key] = value
|
||||
|
||||
# verify state dict by loading it
|
||||
info = lora_network.load_state_dict(ctrl_lora_sd)
|
||||
print(f"loading control lora sd: {info}")
|
||||
|
||||
@@ -183,7 +182,7 @@ def svd(args):
|
||||
# metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||
|
||||
# lora_network.save_weights(args.save_to, save_dtype, metadata)
|
||||
save_file(ctrl_lora_sd, args.save_to)
|
||||
save_to_file(args.save_to, ctrl_lora_sd, save_dtype)
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user