mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
read dim/rank from weights
This commit is contained in:
@@ -1556,8 +1556,18 @@ def main(args):
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(model_file)
|
||||
lora_rank = None
|
||||
emb_dim = None
|
||||
for key, value in state_dict.items():
|
||||
if lora_rank is None and "lora_down.weight" in key:
|
||||
lora_rank = value.shape[0]
|
||||
elif emb_dim is None and "conditioning1.0" in key:
|
||||
emb_dim = value.shape[0]
|
||||
if lora_rank is not None and emb_dim is not None:
|
||||
break
|
||||
assert lora_rank is not None and emb_dim is not None, f"invalid control net: {model_file}"
|
||||
|
||||
control_net = LoRAControlNet(unet, 128, 32, 1) # TODO load from weights
|
||||
control_net = LoRAControlNet(unet, emb_dim, lora_rank, 1)
|
||||
control_net.apply_to()
|
||||
control_net.load_state_dict(state_dict)
|
||||
control_net.to(dtype).to(device)
|
||||
|
||||
Reference in New Issue
Block a user