diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index c3986ef1..df0ba606 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -1,13 +1,14 @@ -import math import argparse +import math import os import time + import torch -from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm + +import lora_flux as lora_flux from library import sai_model_spec, train_util -import networks.lora_flux as lora_flux from library.utils import setup_logging setup_logging() @@ -42,34 +43,181 @@ def save_to_file(file_name, state_dict, dtype, metadata): save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): - # create module map without loading state_dict +def merge_to_flux_model( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype +): logger.info(f"loading keys from FLUX.1 model: {flux_model}") - lora_name_to_module_key = {} - with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: - keys = list(flux_file.keys()) - for key in keys: - if key.endswith(".weight"): - module_name = ".".join(key.split(".")[:-1]) - lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") - lora_name_to_module_key[lora_name] = key - flux_state_dict = load_file(flux_model, device=loading_device) + + def create_key_map(n_double_layers, n_single_layers, hidden_size): + key_map = {} + for index in range(n_double_layers): + prefix_from = f"transformer_blocks.{index}" + prefix_to = f"double_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv_img = f"{prefix_to}.img_attn.qkv.{end}" + qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" + + key_map[f"{k}to_q.{end}"] = (qkv_img, (0, 0, hidden_size)) + key_map[f"{k}to_k.{end}"] = (qkv_img, (0, hidden_size, hidden_size)) + key_map[f"{k}to_v.{end}"] = (qkv_img, (0, hidden_size * 2, hidden_size)) + key_map[f"{k}add_q_proj.{end}"] = (qkv_txt, (0, 0, hidden_size)) + key_map[f"{k}add_k_proj.{end}"] = ( + qkv_txt, + (0, hidden_size, hidden_size), + ) + key_map[f"{k}add_v_proj.{end}"] = ( + qkv_txt, + (0, hidden_size * 2, hidden_size), + ) + + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + for index in range(n_single_layers): + prefix_from = f"single_transformer_blocks.{index}" + prefix_to = f"single_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv = f"{prefix_to}.linear1.{end}" + key_map[f"{k}to_q.{end}"] = (qkv, (0, 0, hidden_size)) + key_map[f"{k}to_k.{end}"] = (qkv, (0, hidden_size, hidden_size)) + key_map[f"{k}to_v.{end}"] = (qkv, (0, hidden_size * 2, hidden_size)) + key_map[f"{prefix_from}.proj_mlp.{end}"] = ( + qkv, + (0, hidden_size * 3, hidden_size * 4), + ) + + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + return key_map + + key_map = create_key_map( + 18, 1, 2048 + ) # Assuming 18 double layers, 1 single layer, and hidden size of 2048 + + def find_matching_key(flux_dict, lora_key): + lora_key = lora_key.replace("diffusion_model.", "") + lora_key = lora_key.replace("transformer.", "") + lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up") + lora_key = lora_key.replace("single_transformer_blocks", "single_blocks") + lora_key = lora_key.replace("transformer_blocks", "double_blocks") + + double_block_map = { + "attn.to_out.0": "img_attn.proj", + "norm1.linear": "img_mod.lin", + "norm1_context.linear": "txt_mod.lin", + "attn.to_add_out": "txt_attn.proj", + "ff.net.0.proj": "img_mlp.0", + "ff.net.2": "img_mlp.2", + "ff_context.net.0.proj": "txt_mlp.0", + "ff_context.net.2": "txt_mlp.2", + "attn.norm_q": "img_attn.norm.query_norm", + "attn.norm_k": "img_attn.norm.key_norm", + "attn.norm_added_q": "txt_attn.norm.query_norm", + "attn.norm_added_k": "txt_attn.norm.key_norm", + "attn.to_q": "img_attn.qkv", + "attn.to_k": "img_attn.qkv", + "attn.to_v": "img_attn.qkv", + "attn.add_q_proj": "txt_attn.qkv", + "attn.add_k_proj": "txt_attn.qkv", + "attn.add_v_proj": "txt_attn.qkv", + } + + single_block_map = { + "norm.linear": "modulation.lin", + "proj_out": "linear2", + "attn.norm_q": "norm.query_norm", + "attn.norm_k": "norm.key_norm", + "attn.to_q": "linear1", + "attn.to_k": "linear1", + "attn.to_v": "linear1", + } + + for old, new in double_block_map.items(): + lora_key = lora_key.replace(old, new) + + for old, new in single_block_map.items(): + lora_key = lora_key.replace(old, new) + + if lora_key in key_map: + flux_key = key_map[lora_key] + if isinstance(flux_key, tuple): + flux_key = flux_key[0] + logger.info(f"Found matching key: {flux_key}") + return flux_key + + # If not found in key_map, try partial matching + potential_key = lora_key + ".weight" + logger.info(f"Searching for key: {potential_key}") + matches = [k for k in flux_dict.keys() if potential_key in k] + if matches: + logger.info(f"Found matching key: {matches[0]}") + return matches[0] + return None + + merged_keys = set() for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") - lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + lora_sd, _ = load_state_dict(model, merge_dtype) - logger.info(f"merging...") + logger.info("merging...") for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - lora_name = key[: key.rfind(".lora_down")] - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" + if "lora_down" in key or "lora_A" in key: + lora_name = key[ + : key.rfind(".lora_down" if "lora_down" in key else ".lora_A") + ] + up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") + alpha_key = ( + key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + + "alpha" + ) - if lora_name not in lora_name_to_module_key: - logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + logger.info(f"Processing LoRA key: {lora_name}") + flux_key = find_matching_key(flux_state_dict, lora_name) + + if flux_key is None: + logger.warning(f"no module found for LoRA weight: {key}") continue + logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}") + down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -77,40 +225,74 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati alpha = lora_sd.get(alpha_key, dim) scale = alpha / dim - # W <- W + U * D - module_weight_key = lora_name_to_module_key[lora_name] - if module_weight_key not in flux_state_dict: - weight = flux_file.get_tensor(module_weight_key) - else: - weight = flux_state_dict[module_weight_key] + weight = flux_state_dict[flux_key] weight = weight.to(working_device, merge_dtype) up_weight = up_weight.to(working_device, merge_dtype) down_weight = down_weight.to(working_device, merge_dtype) - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale + if lora_name.startswith("transformer."): + if "qkv" in flux_key: + hidden_size = weight.size(-1) // 3 + update = ratio * (up_weight @ down_weight) * scale - flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + if "img_attn" in flux_key or "txt_attn" in flux_key: + q, k, v = torch.chunk(weight, 3, dim=-1) + if "to_q" in lora_name or "add_q_proj" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name or "add_k_proj" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name or "add_v_proj" in lora_name: + v += update.reshape(v.shape) + weight = torch.cat([q, k, v], dim=-1) + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + + flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) + merged_keys.add(flux_key) del up_weight del down_weight del weight + logger.info(f"Merged keys: {sorted(list(merged_keys))}") return flux_state_dict @@ -126,7 +308,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_metadata is not None: if base_model is None: - base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + base_model = lora_metadata.get( + train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None + ) # get alpha and dim alphas = {} # alpha for current model @@ -152,10 +336,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info( + f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}" + ) # merge - logger.info(f"merging...") + logger.info("merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: continue @@ -173,14 +359,19 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + scale = ( + abs(scale) if "lora_up" in key else scale + ) # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None - ), f"weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + merged_sd[key].size() == lora_sd[key].size() + or concat_dim is not None + ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" if concat_dim is not None: - merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + merged_sd[key] = torch.cat( + [merged_sd[key], lora_sd[key] * scale], dim=concat_dim + ) else: merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: @@ -199,7 +390,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") - logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info( + f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}" + ) # check all dims are same dims_list = list(set(base_dims.values())) @@ -218,15 +411,17 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): # build minimum metadata dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + metadata = train_util.build_minimum_network_metadata( + str(False), base_model, "networks.lora", dims, alphas, None + ) return merged_sd, metadata def merge(args): - assert len(args.models) == len( - args.ratios - ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert ( + len(args.models) == len(args.ratios) + ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): if p == "float": @@ -249,27 +444,48 @@ def merge(args): if args.flux_model is not None: state_dict = merge_to_flux_model( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, ) if args.no_metadata: sai_metadata = None else: - merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + merged_from = sai_model_spec.build_merged_from( + [args.flux_model] + args.models + ) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" + None, + False, + False, + False, + False, + False, + time.time(), + title=title, + merged_from=merged_from, + flux="dev", ) logger.info(f"saving FLUX model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models( + args.models, args.ratios, merge_dtype, args.concat, args.shuffle + ) - logger.info(f"calculating hashes and creating metadata...") + logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes( + state_dict, metadata + ) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash @@ -277,7 +493,16 @@ def merge(args): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + state_dict, + False, + False, + False, + True, + False, + time.time(), + title=title, + merged_from=merged_from, + flux="dev", ) metadata.update(sai_metadata) @@ -332,7 +557,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", ) - parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--ratios", + type=float, + nargs="*", + help="ratios for each model / それぞれのLoRAモデルの比率", + ) parser.add_argument( "--no_metadata", action="store_true",