mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 14:45:19 +00:00
read dim/rank from weights
This commit is contained in:
@@ -1556,8 +1556,18 @@ def main(args):
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
state_dict = load_file(model_file)
|
state_dict = load_file(model_file)
|
||||||
|
lora_rank = None
|
||||||
|
emb_dim = None
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if lora_rank is None and "lora_down.weight" in key:
|
||||||
|
lora_rank = value.shape[0]
|
||||||
|
elif emb_dim is None and "conditioning1.0" in key:
|
||||||
|
emb_dim = value.shape[0]
|
||||||
|
if lora_rank is not None and emb_dim is not None:
|
||||||
|
break
|
||||||
|
assert lora_rank is not None and emb_dim is not None, f"invalid control net: {model_file}"
|
||||||
|
|
||||||
control_net = LoRAControlNet(unet, 128, 32, 1) # TODO load from weights
|
control_net = LoRAControlNet(unet, emb_dim, lora_rank, 1)
|
||||||
control_net.apply_to()
|
control_net.apply_to()
|
||||||
control_net.load_state_dict(state_dict)
|
control_net.load_state_dict(state_dict)
|
||||||
control_net.to(dtype).to(device)
|
control_net.to(dtype).to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user