fix to support dynamic rank/alpha

This commit is contained in:
Kohya S
2023-03-21 21:59:51 +09:00
parent 4f92b6266c
commit 193674e16c

View File

@@ -208,18 +208,28 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
with torch.no_grad(): with torch.no_grad():
for key, value in tqdm(lora_sd.items()): for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key: if 'lora_down' in key:
block_down_name = key.split(".")[0] block_down_name = key.split(".")[0]
weight_name = key.split(".")[-1]
lora_down_weight = value lora_down_weight = value
if 'lora_up' in key: else:
block_up_name = key.split(".")[0] continue
lora_up_weight = value
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
if (block_down_name == block_up_name) and weights_loaded: if weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4) conv2d = (len(lora_down_weight.size()) == 4)
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha/lora_down_weight.size()[0]
if conv2d: if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@@ -329,7 +339,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None, parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction") help="Specify target for dynamic reduction")
return parser return parser