def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype):
    """Merge LoRA weights into a FLUX.1 checkpoint and return the merged state dict.

    Args:
        loading_device: Device the checkpoint is loaded onto; merged weights are
            moved back to this device before being stored.
        working_device: Device on which the merge arithmetic is performed.
        flux_model: Path to the FLUX.1 safetensors checkpoint.
        models: LoRA files to merge, each read with ``load_state_dict``.
        ratios: Merge strength per LoRA file, paired with ``models`` by position.
        merge_dtype: Dtype used to load the LoRA weights and to do the arithmetic.
        save_dtype: Dtype every patched weight is cast to before being stored.

    Returns:
        The checkpoint state dict with ``W <- W + ratio * scale * (up @ down)``
        applied to every module a LoRA targets. Weights no LoRA touches are
        returned exactly as loaded (they are not cast to ``save_dtype`` here).

    Raises:
        KeyError: If a ``lora_down`` key has no matching ``lora_up`` key.
    """
    logger.info(f"loading keys from FLUX.1 model: {flux_model}")
    flux_state_dict = load_file(flux_model, device=loading_device)

    # Map LoRA module names to the checkpoint key of the weight they patch.
    # The map is built from the loaded dict, so every entry is guaranteed to
    # resolve; no lazy read from a (by then closed) safetensors handle is needed.
    lora_name_to_module_key = {}
    for key in flux_state_dict:
        if key.endswith(".weight"):
            module_name = key.rpartition(".")[0]
            lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
            lora_name_to_module_key[lora_name] = key

    for model, ratio in zip(models, ratios):
        logger.info(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)  # loading on CPU

        logger.info("merging...")
        # Iterate over a snapshot of the keys: consumed entries are popped so
        # that whatever remains afterwards can be reported as unused.
        for key in tqdm(list(lora_sd.keys())):
            if "lora_down" not in key:
                continue

            lora_name = key[: key.rfind(".lora_down")]
            up_key = key.replace("lora_down", "lora_up")
            alpha_key = key[: key.index("lora_down")] + "alpha"

            if lora_name not in lora_name_to_module_key:
                logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.")
                continue

            down_weight = lora_sd.pop(key)
            up_weight = lora_sd.pop(up_key)

            # A missing alpha falls back to dim, which makes the scale 1.
            dim = down_weight.size()[0]
            alpha = lora_sd.pop(alpha_key, dim)
            scale = alpha / dim

            # W <- W + U * D
            module_weight_key = lora_name_to_module_key[lora_name]
            weight = flux_state_dict[module_weight_key]

            weight = weight.to(working_device, merge_dtype)
            up_weight = up_weight.to(working_device, merge_dtype)
            down_weight = down_weight.to(working_device, merge_dtype)

            if len(weight.size()) == 2:
                # linear
                weight = weight + ratio * (up_weight @ down_weight) * scale
            elif down_weight.size()[2:4] == (1, 1):
                # conv2d 1x1: drop the two trailing size-1 dims, multiply, restore them
                weight = (
                    weight
                    + ratio
                    * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                    * scale
                )
            else:
                # conv2d 3x3: compose the down and up kernels with a convolution
                conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                weight = weight + ratio * conved * scale

            flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
            # Drop the working-device copies before the next module is processed.
            del up_weight
            del down_weight
            del weight

        if len(lora_sd) > 0:
            logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}")

    return flux_state_dict
model: {flux_model}") flux_state_dict = load_file(flux_model, device=loading_device) @@ -422,15 +498,14 @@ def merge(args): os.makedirs(dest_dir) if args.flux_model is not None: - state_dict = merge_to_flux_model( - args.loading_device, - args.working_device, - args.flux_model, - args.models, - args.ratios, - merge_dtype, - save_dtype, - ) + if not args.diffusers: + state_dict = merge_to_flux_model( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) + else: + state_dict = merge_to_flux_model_diffusers( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) if args.no_metadata: sai_metadata = None @@ -438,16 +513,7 @@ def merge(args): merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, - False, - False, - False, - False, - False, - time.time(), - title=title, - merged_from=merged_from, - flux="dev", + None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) logger.info(f"saving FLUX model to: {args.save_to}") @@ -466,16 +532,7 @@ def merge(args): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, - False, - False, - False, - True, - False, - time.time(), - title=title, - merged_from=merged_from, - flux="dev", + state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) metadata.update(sai_metadata) @@ -553,6 +610,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) + parser.add_argument( + "--diffusers", + action="store_true", + help="merge Diffusers (?) LoRA models / Diffusers (?) 
LoRAモデルをマージする", + ) return parser