Mirror of https://github.com/kohya-ss/sd-scripts.git
Support concat LoRA
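The commit adds two options to merge_lora_models and wires them through to the command line in both LoRA merge scripts: --concat stacks each model's LoRA factors instead of adding them, so the dim (rank) of the output LoRA becomes the sum of the input dims, and --shuffle randomly permutes the rank dimension of the merged weights.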
@@ -110,7 +110,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
             module.weight = torch.nn.Parameter(weight)


-def merge_lora_models(models, ratios, merge_dtype):
+def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
     base_alphas = {}  # alpha for merged model
     base_dims = {}

@@ -158,6 +158,12 @@ def merge_lora_models(models, ratios, merge_dtype):
         for key in lora_sd.keys():
             if "alpha" in key:
                 continue
+            if "lora_up" in key and concat:
+                concat_dim = 1
+            elif "lora_down" in key and concat:
+                concat_dim = 0
+            else:
+                concat_dim = None

             lora_module_name = key[: key.rfind(".lora_")]

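For orientation (background, not part of the diff): for Linear LoRA modules, lora_down.weight has shape (rank, in_features) and lora_up.weight has shape (out_features, rank), so the rank axis is dim 0 of the down weight and dim 1 of the up weight, which is exactly what the concat_dim selection above encodes. A minimal shape sketch with made-up sizes:

import torch

# Hypothetical Linear-layer LoRA factors: rank 4 and rank 8 on a 320-wide layer.
down_a, down_b = torch.randn(4, 320), torch.randn(8, 320)   # lora_down.weight: (rank, in)
up_a, up_b = torch.randn(320, 4), torch.randn(320, 8)       # lora_up.weight: (out, rank)

print(torch.cat([down_a, down_b], dim=0).shape)  # torch.Size([12, 320]), rank grows
print(torch.cat([up_a, up_b], dim=1).shape)      # torch.Size([320, 12]), rank grows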
@@ -165,12 +171,16 @@ def merge_lora_models(models, ratios, merge_dtype):
             alpha = alphas[lora_module_name]

             scale = math.sqrt(alpha / base_alpha) * ratio
+            scale = abs(scale) if "lora_up" in key else scale  # マイナスの重みに対応する。

             if key in merged_sd:
                 assert (
-                    merged_sd[key].size() == lora_sd[key].size()
+                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
                 ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
-                merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
+                if concat_dim is not None:
+                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
+                else:
+                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
             else:
                 merged_sd[key] = lora_sd[key] * scale

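Why concatenation is a faithful merge (an illustrative check, not part of the commit): by block matrix multiplication, concatenating the up factors along dim 1 and the down factors along dim 0 gives cat(up) @ cat(down) = up_a @ down_a + up_b @ down_b, so the concatenated LoRA reproduces the sum of the per-model low-rank deltas exactly instead of adding factors that may not even share a rank. A sketch with made-up sizes:

import torch

up_a, down_a = torch.randn(320, 4), torch.randn(4, 320)
up_b, down_b = torch.randn(320, 8), torch.randn(8, 320)

up_cat = torch.cat([up_a, up_b], dim=1)        # (320, 12)
down_cat = torch.cat([down_a, down_b], dim=0)  # (12, 320)

# [up_a | up_b] @ [[down_a], [down_b]] = up_a @ down_a + up_b @ down_b
assert torch.allclose(up_cat @ down_cat, up_a @ down_a + up_b @ down_b, atol=1e-4)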
@@ -178,6 +188,13 @@ def merge_lora_models(models, ratios, merge_dtype):
     for lora_module_name, alpha in base_alphas.items():
         key = lora_module_name + ".alpha"
         merged_sd[key] = torch.tensor(alpha)
+        if shuffle:
+            key_down = lora_module_name + ".lora_down.weight"
+            key_up = lora_module_name + ".lora_up.weight"
+            dim = merged_sd[key_down].shape[0]
+            perm = torch.randperm(dim)
+            merged_sd[key_down] = merged_sd[key_down][perm]
+            merged_sd[key_up] = merged_sd[key_up][:,perm]

     print("merged model")
     print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
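A similar illustrative check for --shuffle (not from the commit): applying the same permutation to the rows of lora_down and to the columns of lora_up cancels in the product, so the effective weight delta is unchanged while the individual rank components are reordered:

import torch

up, down = torch.randn(320, 12), torch.randn(12, 320)
perm = torch.randperm(down.shape[0])

# Permute down's rows and up's columns with the same perm, as the patch does.
assert torch.allclose(up[:, perm] @ down[perm], up @ down, atol=1e-4)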
@@ -256,7 +273,7 @@ def merge(args):
             args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
         )
     else:
-        state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype)
+        state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

         print(f"calculating hashes and creating metadata...")

@@ -317,7 +334,19 @@ def setup_parser() -> argparse.ArgumentParser:
         help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
         + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
     )
+    parser.add_argument(
+        "--concat",
+        action="store_true",
+        help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+        + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
+    )
+    parser.add_argument(
+        "--shuffle",
+        action="store_true",
+        help="shuffle lora weight./ "
+        + "LoRAの重みをシャッフルする",
+    )

     return parser


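A hedged usage sketch: assuming this first file is networks/merge_lora.py and reusing the --save_to, --models and --ratios options already referenced by merge() above, a concat-and-shuffle merge would look roughly like `python networks/merge_lora.py --models a.safetensors b.safetensors --ratios 1.0 1.0 --concat --shuffle --save_to merged.safetensors` (other options such as the save precision are omitted here). The remaining hunks apply the same change to the SDXL variant of the script, identified by the merge_to_sd_model signature that takes text_encoder1 and text_encoder2.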
@@ -113,7 +113,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
             module.weight = torch.nn.Parameter(weight)


-def merge_lora_models(models, ratios, merge_dtype):
+def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
     base_alphas = {}  # alpha for merged model
     base_dims = {}

@@ -161,6 +161,13 @@ def merge_lora_models(models, ratios, merge_dtype):
         for key in tqdm(lora_sd.keys()):
             if "alpha" in key:
                 continue

+            if "lora_up" in key and concat:
+                concat_dim = 1
+            elif "lora_down" in key and concat:
+                concat_dim = 0
+            else:
+                concat_dim = None
+
             lora_module_name = key[: key.rfind(".lora_")]

@@ -168,12 +175,16 @@ def merge_lora_models(models, ratios, merge_dtype):
             alpha = alphas[lora_module_name]

             scale = math.sqrt(alpha / base_alpha) * ratio
+            scale = abs(scale) if "lora_up" in key else scale  # マイナスの重みに対応する。

             if key in merged_sd:
                 assert (
-                    merged_sd[key].size() == lora_sd[key].size()
+                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
                 ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
-                merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
+                if concat_dim is not None:
+                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
+                else:
+                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
             else:
                 merged_sd[key] = lora_sd[key] * scale

@@ -181,6 +192,13 @@ def merge_lora_models(models, ratios, merge_dtype):
     for lora_module_name, alpha in base_alphas.items():
         key = lora_module_name + ".alpha"
         merged_sd[key] = torch.tensor(alpha)
+        if shuffle:
+            key_down = lora_module_name + ".lora_down.weight"
+            key_up = lora_module_name + ".lora_up.weight"
+            dim = merged_sd[key_down].shape[0]
+            perm = torch.randperm(dim)
+            merged_sd[key_down] = merged_sd[key_down][perm]
+            merged_sd[key_up] = merged_sd[key_up][:,perm]

     print("merged model")
     print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -252,7 +270,7 @@ def merge(args):
             args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
         )
     else:
-        state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype)
+        state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

         print(f"calculating hashes and creating metadata...")

@@ -307,6 +325,18 @@ def setup_parser() -> argparse.ArgumentParser:
         help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
         + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
     )
+    parser.add_argument(
+        "--concat",
+        action="store_true",
+        help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+        + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
+    )
+    parser.add_argument(
+        "--shuffle",
+        action="store_true",
+        help="shuffle lora weight./ "
+        + "LoRAの重みをシャッフルする",
+    )

     return parser

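One way (illustrative, with an assumed output file name) to confirm the effect of --concat on the saved LoRA is to check that each module's lora_down.weight now has a rank equal to the sum of the input dims:

from safetensors.torch import load_file

sd = load_file("merged.safetensors")  # hypothetical output of the merge above
ranks = {k: v.shape[0] for k, v in sd.items() if k.endswith(".lora_down.weight")}
print(sorted(set(ranks.values())))  # e.g. [12] if the inputs were rank 4 and rank 8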