read dim/rank from weights

This commit is contained in:
Kohya S
2023-08-17 12:10:52 +09:00
parent 306ee24c90
commit afc03af3ca

View File

@@ -1556,8 +1556,18 @@ def main(args):
from safetensors.torch import load_file
state_dict = load_file(model_file)
lora_rank = None
emb_dim = None
for key, value in state_dict.items():
if lora_rank is None and "lora_down.weight" in key:
lora_rank = value.shape[0]
elif emb_dim is None and "conditioning1.0" in key:
emb_dim = value.shape[0]
if lora_rank is not None and emb_dim is not None:
break
assert lora_rank is not None and emb_dim is not None, f"invalid control net: {model_file}"
control_net = LoRAControlNet(unet, 128, 32, 1) # TODO load from weights
control_net = LoRAControlNet(unet, emb_dim, lora_rank, 1)
control_net.apply_to()
control_net.load_state_dict(state_dict)
control_net.to(dtype).to(device)