diff --git a/networks/extract_control_net_lora.py b/networks/extract_control_net_lora.py index d2c460ba..1c25b486 100644 --- a/networks/extract_control_net_lora.py +++ b/networks/extract_control_net_lora.py @@ -19,16 +19,16 @@ CLAMP_QUANTILE = 0.99 MIN_DIFF = 1e-6 -def save_to_file(file_name, model, state_dict, dtype): +def save_to_file(file_name, state_dict, dtype): if dtype is not None: for key in list(state_dict.keys()): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) + save_file(state_dict, file_name) else: - torch.save(model, file_name) + torch.save(state_dict, file_name) def svd(args): @@ -45,15 +45,15 @@ def svd(args): # Diffusersのキーに変換するため、original sdとcontrol sdからU-Netに重みを読み込む ############### - # original sdをDiffusersに読み込む + # original sdをDiffusersのU-Netに読み込む print(f"loading original SD model : {args.model_org}") - org_text_encoder, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + _, _, org_unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) org_sd = torch.load(args.model_org, map_location='cpu') if 'state_dict' in org_sd: org_sd = org_sd['state_dict'] - # control sdからキー変換しつつU-Netに対応する部分のみ取り出す + # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのuU-Netに読み込む print(f"loading control SD model : {args.model_tuned}") ctrl_sd = torch.load(args.model_tuned, map_location='cpu') @@ -65,8 +65,8 @@ def svd(args): continue ctrl_unet_sd[unet_key] = ctrl_sd[key] - unet_config = model_util.create_unet_diffusers_config(False) - ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(False, ctrl_unet_sd, unet_config) + unet_config = model_util.create_unet_diffusers_config(args.v2) + ctrl_unet_sd_du = model_util.convert_ldm_unet_checkpoint(args.v2, ctrl_unet_sd, unet_config) # load weights to U-Net ctrl_unet = UNet2DConditionModel(**unet_config) @@ -75,8 +75,6 @@ def svd(args): # LoRAに対応する部分のU-Netの重みを読み込む ################################# - org_unet_sd_du = org_unet.state_dict() - diffs = {} for (org_name, org_module), (ctrl_name, ctrl_module) in zip(org_unet.named_modules(), ctrl_unet.named_modules()): if org_module.__class__.__name__ != "Linear" and org_module.__class__.__name__ != "Conv2d": @@ -172,6 +170,7 @@ def svd(args): if 'zero_convs' in key or 'input_hint_block' in key or 'middle_block_out' in key: ctrl_lora_sd[key] = value + # verify state dict by loading it info = lora_network.load_state_dict(ctrl_lora_sd) print(f"loading control lora sd: {info}") @@ -183,7 +182,7 @@ def svd(args): # metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} # lora_network.save_weights(args.save_to, save_dtype, metadata) - save_file(ctrl_lora_sd, args.save_to) + save_to_file(args.save_to, ctrl_lora_sd, save_dtype) print(f"LoRA weights are saved to: {args.save_to}")