From afc03af3ca0ff67cbcd7991654329ee1c311b301 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 17 Aug 2023 12:10:52 +0900 Subject: [PATCH] read dim/rank from weights --- sdxl_gen_img_lora_ctrl_test.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sdxl_gen_img_lora_ctrl_test.py b/sdxl_gen_img_lora_ctrl_test.py index 4820aa3f..e8b22ee1 100644 --- a/sdxl_gen_img_lora_ctrl_test.py +++ b/sdxl_gen_img_lora_ctrl_test.py @@ -1556,8 +1556,18 @@ def main(args): from safetensors.torch import load_file state_dict = load_file(model_file) + lora_rank = None + emb_dim = None + for key, value in state_dict.items(): + if lora_rank is None and "lora_down.weight" in key: + lora_rank = value.shape[0] + elif emb_dim is None and "conditioning1.0" in key: + emb_dim = value.shape[0] + if lora_rank is not None and emb_dim is not None: + break + assert lora_rank is not None and emb_dim is not None, f"invalid control net: {model_file}" - control_net = LoRAControlNet(unet, 128, 32, 1) # TODO load from weights + control_net = LoRAControlNet(unet, emb_dim, lora_rank, 1) control_net.apply_to() control_net.load_state_dict(state_dict) control_net.to(dtype).to(device)