mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #760 from Symbiomatrix/bugfix1
Update resize_lora.py
This commit is contained in:
@@ -219,16 +219,16 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
for key, value in tqdm(lora_sd.items()):
|
for key, value in tqdm(lora_sd.items()):
|
||||||
weight_name = None
|
weight_name = None
|
||||||
if 'lora_down' in key:
|
if 'lora_down' in key:
|
||||||
block_down_name = key.split(".")[0]
|
block_down_name = key.rsplit('lora_down', 1)[0]
|
||||||
weight_name = key.split(".")[-1]
|
weight_name = key.rsplit(".", 1)[-1]
|
||||||
lora_down_weight = value
|
lora_down_weight = value
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# find corresponding lora_up and alpha
|
# find corresponding lora_up and alpha
|
||||||
block_up_name = block_down_name
|
block_up_name = block_down_name
|
||||||
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
|
lora_up_weight = lora_sd.get(block_up_name + 'lora_up.' + weight_name, None)
|
||||||
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
|
lora_alpha = lora_sd.get(block_down_name + 'alpha', None)
|
||||||
|
|
||||||
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
||||||
|
|
||||||
@@ -263,9 +263,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
verbose_str+=f"\n"
|
verbose_str+=f"\n"
|
||||||
|
|
||||||
new_alpha = param_dict['new_alpha']
|
new_alpha = param_dict['new_alpha']
|
||||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
o_lora_sd[block_down_name + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
o_lora_sd[block_up_name + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
|
o_lora_sd[block_up_name + "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
|
||||||
|
|
||||||
block_down_name = None
|
block_down_name = None
|
||||||
block_up_name = None
|
block_up_name = None
|
||||||
|
|||||||
Reference in New Issue
Block a user