mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to support dynamic rank/alpha
This commit is contained in:
@@ -208,18 +208,28 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for key, value in tqdm(lora_sd.items()):
|
for key, value in tqdm(lora_sd.items()):
|
||||||
|
weight_name = None
|
||||||
if 'lora_down' in key:
|
if 'lora_down' in key:
|
||||||
block_down_name = key.split(".")[0]
|
block_down_name = key.split(".")[0]
|
||||||
|
weight_name = key.split(".")[-1]
|
||||||
lora_down_weight = value
|
lora_down_weight = value
|
||||||
if 'lora_up' in key:
|
else:
|
||||||
block_up_name = key.split(".")[0]
|
continue
|
||||||
lora_up_weight = value
|
|
||||||
|
# find corresponding lora_up and alpha
|
||||||
|
block_up_name = block_down_name
|
||||||
|
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
|
||||||
|
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
|
||||||
|
|
||||||
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
||||||
|
|
||||||
if (block_down_name == block_up_name) and weights_loaded:
|
if weights_loaded:
|
||||||
|
|
||||||
conv2d = (len(lora_down_weight.size()) == 4)
|
conv2d = (len(lora_down_weight.size()) == 4)
|
||||||
|
if lora_alpha is None:
|
||||||
|
scale = 1.0
|
||||||
|
else:
|
||||||
|
scale = lora_alpha/lora_down_weight.size()[0]
|
||||||
|
|
||||||
if conv2d:
|
if conv2d:
|
||||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||||
|
|||||||
Reference in New Issue
Block a user