From 9d678a6f41d0015f9bf8767ab614051e51f7405b Mon Sep 17 00:00:00 2001 From: Symbiomatrix Date: Wed, 16 Aug 2023 00:08:09 +0300 Subject: [PATCH] Update resize_lora.py --- networks/resize_lora.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 7b740634..90e5dffc 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -219,16 +219,16 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn for key, value in tqdm(lora_sd.items()): weight_name = None if 'lora_down' in key: - block_down_name = key.split(".")[0] - weight_name = key.split(".")[-1] + block_down_name = key.rsplit('lora_down', 1)[0] + weight_name = key.rsplit(".", 1)[-1] lora_down_weight = value else: continue # find corresponding lora_up and alpha block_up_name = block_down_name - lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) - lora_alpha = lora_sd.get(block_down_name + '.alpha', None) + lora_up_weight = lora_sd.get(block_up_name + 'lora_up.' + weight_name, None) + lora_alpha = lora_sd.get(block_down_name + 'alpha', None) weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) @@ -263,9 +263,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn verbose_str+=f"\n" new_alpha = param_dict['new_alpha'] - o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) + o_lora_sd[block_down_name + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) block_down_name = None block_up_name = None