fix to support dynamic rank/alpha

Kohya S
2023-03-21 21:59:51 +09:00
parent 4f92b6266c
commit 193674e16c

@@ -208,18 +208,28 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
   with torch.no_grad():
     for key, value in tqdm(lora_sd.items()):
+      weight_name = None
       if 'lora_down' in key:
         block_down_name = key.split(".")[0]
+        weight_name = key.split(".")[-1]
         lora_down_weight = value
-      if 'lora_up' in key:
-        block_up_name = key.split(".")[0]
-        lora_up_weight = value
+      else:
+        continue
+
+      # find corresponding lora_up and alpha
+      block_up_name = block_down_name
+      lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
+      lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
 
       weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
 
-      if (block_down_name == block_up_name) and weights_loaded:
+      if weights_loaded:
         conv2d = (len(lora_down_weight.size()) == 4)
 
+        if lora_alpha is None:
+          scale = 1.0
+        else:
+          scale = lora_alpha/lora_down_weight.size()[0]
+
         if conv2d:
           full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
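For context, below is a minimal standalone sketch (not part of the commit) of the pairing and scaling logic the new code relies on: each `lora_down` key is matched to its `lora_up` and `alpha` entries by name, and the per-module scale is alpha divided by the rank (the first dimension of `lora_down`). The module name, shapes, and alpha value in the dummy state dict are illustrative assumptions.

# sketch.py -- illustrative only; real LoRA files use the same key layout
# but contain many modules and are loaded from a .safetensors/.ckpt file.
import torch

lora_sd = {
    "lora_unet_mid_block.lora_down.weight": torch.randn(8, 320),  # rank 8
    "lora_unet_mid_block.lora_up.weight": torch.randn(320, 8),
    "lora_unet_mid_block.alpha": torch.tensor(4.0),
}

for key, value in lora_sd.items():
    if "lora_down" not in key:
        continue  # only lora_down keys drive the pairing
    block_down_name = key.split(".")[0]
    weight_name = key.split(".")[-1]  # "weight"
    lora_down_weight = value

    # corresponding lora_up and alpha are looked up by name
    lora_up_weight = lora_sd.get(block_down_name + ".lora_up." + weight_name, None)
    lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

    # scale = alpha / rank; rank is the first dimension of lora_down
    rank = lora_down_weight.size()[0]
    scale = 1.0 if lora_alpha is None else float(lora_alpha) / rank
    print(block_down_name, "rank:", rank, "scale:", scale)  # -> rank: 8, scale: 0.5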