Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 15:00:23 +00:00)
Support resizing ControlLoRA
@@ -24,6 +24,7 @@ MIN_SV = 1e-6
 LORAFMT1 = ["lora_down", "lora_up"]
 LORAFMT2 = ["lora.down", "lora.up"]
 LORAFMT3 = ["lora_A", "lora_B"]
+LORAFMT4 = ["down", "up"]
 LORAFMT = LORAFMT1

 # Model save and load functions
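The new LORAFMT4 pair covers ControlLoRA-style state dicts, whose down/up tensors are named with plain "down"/"up" rather than the "lora_down"/"lora_up", "lora.down"/"lora.up", or "lora_A"/"lora_B" pairs already supported. A minimal sketch of what a down-weight key might look like under each scheme; the module names and the exact ControlLoRA key layout are hypothetical, not taken from a specific checkpoint:

# Hypothetical example keys for each naming scheme; only the down/up tokens
# come from the definitions above, the rest is illustrative.
EXAMPLE_DOWN_KEYS = {
    "LORAFMT1": "lora_unet_mid_block_attn1.lora_down.weight",  # kohya-style LoRA
    "LORAFMT2": "mid_block.attn1.lora.down.weight",            # diffusers-style
    "LORAFMT3": "mid_block.attn1.lora_A.weight",               # PEFT-style
    "LORAFMT4": "control_model.mid_block.attn1.down",          # ControlLoRA-style (assumed)
}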
@@ -209,13 +210,15 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         if network_alpha is None and "alpha" in key:
             network_alpha = value
         if (network_dim is None and len(value.size()) == 2
-                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key)):
+                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key or LORAFMT4[0] in key)):
             if LORAFMT1[0] in key:
                 LORAFMT = LORAFMT1
             elif LORAFMT2[0] in key:
                 LORAFMT = LORAFMT2
             elif LORAFMT3[0] in key:
                 LORAFMT = LORAFMT3
+            elif LORAFMT4[0] in key:
+                LORAFMT = LORAFMT4
             network_dim = value.size()[0]
         if network_alpha is not None and network_dim is not None:
             break
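Because "down" is a substring of both "lora_down" and "lora.down", the new branch has to come after the more specific checks, which is why LORAFMT4 is appended as the last elif. A small standalone sketch of the same detection order, redefining the constants so it runs on its own:

LORAFMT1 = ["lora_down", "lora_up"]
LORAFMT2 = ["lora.down", "lora.up"]
LORAFMT3 = ["lora_A", "lora_B"]
LORAFMT4 = ["down", "up"]

def detect_format(key):
    # Check the more specific tokens first; "down" would also match
    # "lora_down" and "lora.down" keys, so LORAFMT4 is the fallback.
    for fmt in (LORAFMT1, LORAFMT2, LORAFMT3, LORAFMT4):
        if fmt[0] in key:
            return fmt
    return None

print(detect_format("foo.lora_down.weight"))  # ['lora_down', 'lora_up']
print(detect_format("foo.down"))              # ['down', 'up']  (hypothetical ControlLoRA-style key)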
@@ -241,14 +244,17 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
             weight_name = None
             if LORAFMT[0] in key:
                 block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
-                weight_name = key.rsplit(".", 1)[-1]
+                if key.endswith(f".{LORAFMT[0]}"):
+                    weight_name = ""
+                else:
+                    weight_name = key.rsplit(f".{LORAFMT[0]}", 1)[-1]
                 lora_down_weight = value
             else:
                 continue

             # find corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}." + weight_name, None)
+            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}" + weight_name, None)
             lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

             weights_loaded = lora_down_weight is not None and lora_up_weight is not None
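Deriving the up key from the down key now carries over whatever follows the down token: ".weight" for conventional LoRA keys, or nothing at all when the key ends in the token itself, as ControlLoRA-style keys apparently do. A hedged sketch of that derivation; the key names are hypothetical:

def up_key_for(down_key, fmt):
    # Split off the block name and keep the suffix after the down token,
    # e.g. ".weight" for standard LoRA keys or "" for keys ending in the token.
    block_name = down_key.rsplit(f".{fmt[0]}", 1)[0]
    if down_key.endswith(f".{fmt[0]}"):
        suffix = ""
    else:
        suffix = down_key.rsplit(f".{fmt[0]}", 1)[-1]
    return block_name + f".{fmt[1]}" + suffix

print(up_key_for("foo.lora_down.weight", ["lora_down", "lora_up"]))  # foo.lora_up.weight
print(up_key_for("foo.down", ["down", "up"]))                        # foo.up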
@@ -286,8 +292,8 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
                 verbose_str += "\n"

             new_alpha = param_dict["new_alpha"]
-            o_lora_sd[block_down_name + f".{LORAFMT[0]}.weight"] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
-            o_lora_sd[block_up_name + f".{LORAFMT[1]}.weight"] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
+            o_lora_sd[block_down_name + f".{LORAFMT[0]}" + weight_name] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
+            o_lora_sd[block_up_name + f".{LORAFMT[1]}" + weight_name] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
             o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

             block_down_name = None
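The resized tensors are written back with the same suffix the input keys carried, so a resized ControlLoRA keeps its original key layout instead of gaining a ".weight" suffix it never had. A small illustration of the resulting output keys, using hypothetical input keys:

# weight_name plays the same role as in the hunk above: the suffix captured
# from the input down key, reused when writing the resized tensors.
cases = [
    ("foo.lora_down.weight", ["lora_down", "lora_up"]),  # standard LoRA
    ("foo.down",             ["down", "up"]),            # ControlLoRA-style (assumed)
]
for down_key, fmt in cases:
    block = down_key.rsplit(f".{fmt[0]}", 1)[0]
    weight_name = "" if down_key.endswith(f".{fmt[0]}") else down_key.rsplit(f".{fmt[0]}", 1)[-1]
    out_keys = [block + f".{fmt[0]}" + weight_name, block + f".{fmt[1]}" + weight_name, block + ".alpha"]
    print(out_keys)
# ['foo.lora_down.weight', 'foo.lora_up.weight', 'foo.alpha']
# ['foo.down', 'foo.up', 'foo.alpha']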