From 68e0767404d88911079a052a2c40762f98ae58e1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 12 Apr 2023 23:40:10 +0900 Subject: [PATCH] add comment about scaling --- networks/extract_lora_from_dylora.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 0037636f..9ae4056e 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -3,6 +3,7 @@ # Thanks to cloneofsimo import argparse +import math import os import torch from safetensors.torch import load_file, save_file, safe_open @@ -43,6 +44,7 @@ def split_lora_model(lora_sd, unit): rank = unit splitted_models = [] + new_alpha = None while rank < max_rank: print(f"Splitting rank {rank}") new_sd = {} @@ -52,9 +54,15 @@ def split_lora_model(lora_sd, unit): elif "lora_up" in key: new_sd[key] = value[:, :rank].contiguous() else: - new_sd[key] = value # alpha and other parameters + # なぜかscaleするとおかしくなる…… + # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] + # scale = math.sqrt(this_rank / rank) # rank is > unit + # print(key, value.size(), this_rank, rank, value, scale) + # new_alpha = value * scale # always same + # new_sd[key] = new_alpha + new_sd[key] = value - splitted_models.append((new_sd, rank)) + splitted_models.append((new_sd, rank, new_alpha)) rank += unit return max_rank, splitted_models @@ -68,7 +76,7 @@ def split(args): original_rank, splitted_models = split_lora_model(lora_sd, args.unit) comment = metadata.get("ss_training_comment", "") - for state_dict, new_rank in splitted_models: + for state_dict, new_rank, new_alpha in splitted_models: # update metadata if metadata is None: new_metadata = {} @@ -77,6 +85,7 @@ def split(args): new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}" new_metadata["ss_network_dim"] = str(new_rank) + # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy()) model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash