diff --git a/sdxl_gen_img_lora_ctrl_test.py b/sdxl_gen_img_lora_ctrl_test.py index 4820aa3f..e8b22ee1 100644 --- a/sdxl_gen_img_lora_ctrl_test.py +++ b/sdxl_gen_img_lora_ctrl_test.py @@ -1556,8 +1556,18 @@ def main(args): from safetensors.torch import load_file state_dict = load_file(model_file) + lora_rank = None + emb_dim = None + for key, value in state_dict.items(): + if lora_rank is None and "lora_down.weight" in key: + lora_rank = value.shape[0] + elif emb_dim is None and "conditioning1.0" in key: + emb_dim = value.shape[0] + if lora_rank is not None and emb_dim is not None: + break + assert lora_rank is not None and emb_dim is not None, f"invalid control net: {model_file}" - control_net = LoRAControlNet(unet, 128, 32, 1) # TODO load from weights + control_net = LoRAControlNet(unet, emb_dim, lora_rank, 1) control_net.apply_to() control_net.load_state_dict(state_dict) control_net.to(dtype).to(device)