From 193674e16cca54c1b78f2ffb06477aed9561b4b8 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 21 Mar 2023 21:59:51 +0900
Subject: [PATCH] fix to support dynamic rank/alpha

---
 networks/resize_lora.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index 15769a3f..2bd86599 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -208,18 +208,28 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
   with torch.no_grad():
     for key, value in tqdm(lora_sd.items()):
+      weight_name = None
       if 'lora_down' in key:
         block_down_name = key.split(".")[0]
+        weight_name = key.split(".")[-1]
         lora_down_weight = value
-      if 'lora_up' in key:
-        block_up_name = key.split(".")[0]
-        lora_up_weight = value
+      else:
+        continue
+
+      # find corresponding lora_up and alpha
+      block_up_name = block_down_name
+      lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
+      lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
 
       weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
 
-      if (block_down_name == block_up_name) and weights_loaded:
+      if weights_loaded:
 
         conv2d = (len(lora_down_weight.size()) == 4)
+        if lora_alpha is None:
+          scale = 1.0
+        else:
+          scale = lora_alpha/lora_down_weight.size()[0]
 
         if conv2d:
           full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@@ -329,7 +339,7 @@ def setup_parser() -> argparse.ArgumentParser:
                       help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
   parser.add_argument("--dynamic_param", type=float, default=None,
                       help="Specify target for dynamic reduction")
-
+
   return parser
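
The behavioral core of this patch: instead of assuming `lora_down` and `lora_up` arrive as adjacent keys, each `lora_down` entry now looks up its matching `lora_up` and optional per-module `alpha` directly in the state dict, and the merge scale becomes alpha / rank (the rank being the first dimension of `lora_down`), defaulting to 1.0 when no alpha is stored. A minimal standalone sketch of that per-module lookup, using a toy state dict in the same `<block>.lora_down.weight` / `<block>.lora_up.weight` / `<block>.alpha` key layout — the block name and shapes below are hypothetical, for illustration only, not part of the patch:

import torch

# Toy state dict; "lora_unet_block0" and the (8, 320) shapes are made up.
lora_sd = {
    "lora_unet_block0.lora_down.weight": torch.randn(8, 320),   # rank 8
    "lora_unet_block0.lora_up.weight": torch.randn(320, 8),
    "lora_unet_block0.alpha": torch.tensor(4.0),
}

for key, down in lora_sd.items():
    if "lora_down" not in key:
        continue  # skip lora_up and alpha entries, as the patched loop does
    block = key.split(".")[0]
    weight_name = key.split(".")[-1]  # "weight"
    up = lora_sd.get(block + ".lora_up." + weight_name)
    alpha = lora_sd.get(block + ".alpha")
    # rank is the output dimension of lora_down; no alpha means scale 1.0
    rank = down.size()[0]
    scale = 1.0 if alpha is None else float(alpha) / rank
    print(block, "rank:", rank, "scale:", scale)

With alpha = 4 and rank = 8 this prints scale = 0.5, matching `scale = lora_alpha/lora_down_weight.size()[0]` in the patched loop; a module with a different stored alpha or rank gets its own scale, which is what lets resizing handle dynamic rank/alpha per block.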