diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index 49b71481..d8a37da2 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -24,6 +24,7 @@ MIN_SV = 1e-6
 LORAFMT1 = ["lora_down", "lora_up"]
 LORAFMT2 = ["lora.down", "lora.up"]
 LORAFMT3 = ["lora_A", "lora_B"]
+LORAFMT4 = ["down", "up"]
 LORAFMT = LORAFMT1
 
 # Model save and load functions
@@ -209,13 +210,15 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         if network_alpha is None and "alpha" in key:
             network_alpha = value
         if (network_dim is None and len(value.size()) == 2
-                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key)):
+                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key or LORAFMT4[0] in key)):
             if LORAFMT1[0] in key:
                 LORAFMT = LORAFMT1
             elif LORAFMT2[0] in key:
                 LORAFMT = LORAFMT2
             elif LORAFMT3[0] in key:
                 LORAFMT = LORAFMT3
+            elif LORAFMT4[0] in key:
+                LORAFMT = LORAFMT4
             network_dim = value.size()[0]
         if network_alpha is not None and network_dim is not None:
             break
@@ -241,14 +244,17 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
             weight_name = None
             if LORAFMT[0] in key:
                 block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
-                weight_name = key.rsplit(".", 1)[-1]
+                if key.endswith(f".{LORAFMT[0]}"):
+                    weight_name = ""
+                else:
+                    weight_name = key.rsplit(f".{LORAFMT[0]}", 1)[-1]
                 lora_down_weight = value
             else:
                 continue
 
             # find corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}." + weight_name, None)
+            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}" + weight_name, None)
             lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
 
             weights_loaded = lora_down_weight is not None and lora_up_weight is not None
@@ -286,8 +292,8 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
                 verbose_str += "\n"
             new_alpha = param_dict["new_alpha"]
 
-            o_lora_sd[block_down_name + f".{LORAFMT[0]}.weight"] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
-            o_lora_sd[block_up_name + f".{LORAFMT[1]}.weight"] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
+            o_lora_sd[block_down_name + f".{LORAFMT[0]}" + weight_name] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
+            o_lora_sd[block_up_name + f".{LORAFMT[1]}" + weight_name] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
             o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
 
             block_down_name = None
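
To illustrate what the key-handling change does, here is a standalone sketch
(with hypothetical key names, not part of the patch). After this change,
weight_name keeps everything following the down-token, including the leading
dot (e.g. ".weight"), or is empty when the key ends at the token itself, so
one concatenation rebuilds the matching up-key for every format. The previous
key.rsplit(".", 1)[-1] would have returned the bare token ("down") for a
LORAFMT4-style key and produced a wrong lookup.

# Standalone sketch of the new suffix handling (hypothetical key names).
def split_down_key(key: str, fmt_down: str, fmt_up: str):
    """Return (block_name, matching_up_key) for a LoRA down-weight key."""
    block = key.rsplit(f".{fmt_down}", 1)[0]
    if key.endswith(f".{fmt_down}"):
        suffix = ""  # LORAFMT4-style: key ends at the token, no ".weight" suffix
    else:
        suffix = key.rsplit(f".{fmt_down}", 1)[-1]  # e.g. ".weight", keeps the dot
    return block, f"{block}.{fmt_up}{suffix}"

# LORAFMT1-style key carries a ".weight" suffix that must be preserved:
assert split_down_key("lora_unet_blocks_0.lora_down.weight", "lora_down", "lora_up") == (
    "lora_unet_blocks_0",
    "lora_unet_blocks_0.lora_up.weight",
)
# LORAFMT4-style key ends at the token, so the suffix is empty:
assert split_down_key("blocks.0.attn.down", "down", "up") == (
    "blocks.0.attn",
    "blocks.0.attn.up",
)

The same suffix is reused when writing o_lora_sd, which is why the hard-coded
".weight" in the two save lines is replaced by weight_name.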