From 1d7118a62268f12ebfd81c10db53bd85ef9d7631 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:01:36 +0900 Subject: [PATCH 1/3] Support : OFT merge to base model (#1580) * Support : OFT merge to base model * Fix typo * Fix typo_2 * Delete unused parameter 'eye' --- networks/sdxl_merge_lora.py | 190 +++++++++++++++++++++++++++--------- 1 file changed, 143 insertions(+), 47 deletions(-) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 3383a80d..2c998c8c 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,10 +8,12 @@ from tqdm import tqdm from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora +import oft from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) +import concurrent.futures def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -39,82 +41,176 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): else: torch.save(model, file_name) +def detect_method_from_training_model(models, dtype): + for model in models: + lora_sd, _ = load_state_dict(model, dtype) + for key in tqdm(lora_sd.keys()): + if 'lora_up' in key or 'lora_down' in key: + return 'LoRA' + elif "oft_blocks" in key: + return 'OFT' def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype) unet.to(merge_dtype) + + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + logger.info(f"method:{method}") # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if i <= 1: - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + if method == 'LoRA': + if i <= 1: + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + else: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = ( + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + elif method == 'OFT': + prefix = oft.OFTNetwork.OFT_PREFIX_UNET + target_replace_modules = ( + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - + if method == 'LoRA': + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + elif method == 'OFT': + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + name_to_module[oft_name] = child_module + + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - # find original module for this lora - module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if method == 'LoRA': + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + logger.info(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + + elif method == 'OFT': + + multiplier=1.0 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + for key in tqdm(lora_sd.keys()): + if "oft_blocks" in key: + oft_blocks = lora_sd[key] + dim = oft_blocks.shape[0] + break + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + oft_blocks = lora_sd[key] + alpha = oft_blocks.item() + break + + def merge_to(key): + if "alpha" in key: + return + + # find original module for this OFT + module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: - logger.info(f"no module found for LoRA weight: {key}") - continue + return module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + oft_blocks = lora_sd[key] + + if isinstance(module, torch.nn.Linear): + out_dim = module.out_features + elif isinstance(module, torch.nn.Conv2d): + out_dim = module.out_channels + + num_blocks = dim + block_size = out_dim // dim + constraint = (0 if alpha is None else alpha) * out_dim + + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * I + R = torch.block_diag(*block_R_weighted) + + # get org weight + org_sd = module.state_dict() + org_weight = org_sd["weight"].to(device) - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) + R = R.to(org_weight.device, dtype=org_weight.dtype) + + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - + weight = torch.einsum("oi, op -> pi", org_weight, R) + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor + module.weight = torch.nn.Parameter(weight) + with concurrent.futures.ThreadPoolExecutor() as executor: + list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) + def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model From 57ae44eb6138fe4a3864fffa62090f9d0113417d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 19:45:00 +0900 Subject: [PATCH 2/3] refactor to make safer --- networks/sdxl_merge_lora.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 2c998c8c..d5a54e02 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -44,11 +44,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): def detect_method_from_training_model(models, dtype): for model in models: lora_sd, _ = load_state_dict(model, dtype) - for key in tqdm(lora_sd.keys()): - if 'lora_up' in key or 'lora_down' in key: - return 'LoRA' - elif "oft_blocks" in key: - return 'OFT' + for key in tqdm(lora_sd.keys()): + if 'lora_up' in key or 'lora_down' in key: + return 'LoRA' + elif "oft_blocks" in key: + return 'OFT' def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) @@ -76,6 +76,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ ) elif method == 'OFT': prefix = oft.OFTNetwork.OFT_PREFIX_UNET + # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) @@ -83,17 +84,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if method == 'LoRA': - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - elif method == 'OFT': - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - oft_name = prefix + "." + name + "." + child_name - oft_name = oft_name.replace(".", "_") - name_to_module[oft_name] = child_module - + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") @@ -168,6 +163,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # find original module for this OFT module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: + logger.info(f"no module found for OFT weight: {key}") return module = name_to_module[module_name] @@ -208,7 +204,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module.weight = torch.nn.Parameter(weight) - with concurrent.futures.ThreadPoolExecutor() as executor: + # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) From 3387dc7306087b84646666e49323980c89d14945 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 19:45:42 +0900 Subject: [PATCH 3/3] formatting, update README --- README.md | 6 +++ networks/sdxl_merge_lora.py | 86 +++++++++++++++++++++---------------- 2 files changed, 54 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index fd81a781..d5d2a7f7 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Sep 13, 2024 / 2024-09-13: + +- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release. + +- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。 + ### Jun 23, 2024 / 2024-06-23: - Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index d5a54e02..d5b6f7f3 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -10,11 +10,14 @@ import library.model_util as model_util import lora import oft from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) import concurrent.futures + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": sd = load_file(file_name) @@ -41,20 +44,22 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): else: torch.save(model, file_name) + def detect_method_from_training_model(models, dtype): for model in models: lora_sd, _ = load_state_dict(model, dtype) for key in tqdm(lora_sd.keys()): - if 'lora_up' in key or 'lora_down' in key: - return 'LoRA' + if "lora_up" in key or "lora_down" in key: + return "LoRA" elif "oft_blocks" in key: - return 'OFT' + return "OFT" + def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype) unet.to(merge_dtype) - + # detect the method: OFT or LoRA_module method = detect_method_from_training_model(models, merge_dtype) logger.info(f"method:{method}") @@ -62,7 +67,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if method == 'LoRA': + if method == "LoRA": if i <= 1: if i == 0: prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 @@ -72,9 +77,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ else: prefix = lora.LoRANetwork.LORA_PREFIX_UNET target_replace_modules = ( - lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) - elif method == 'OFT': + elif method == "OFT": prefix = oft.OFTNetwork.OFT_PREFIX_UNET # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( @@ -88,15 +93,14 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - - + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - if method == 'LoRA': + if method == "LoRA": for key in tqdm(lora_sd.keys()): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -139,12 +143,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module.weight = torch.nn.Parameter(weight) - - elif method == 'OFT': - - multiplier=1.0 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + elif method == "OFT": + + multiplier = 1.0 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + for key in tqdm(lora_sd.keys()): if "oft_blocks" in key: oft_blocks = lora_sd[key] @@ -154,12 +157,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ if "alpha" in key: oft_blocks = lora_sd[key] alpha = oft_blocks.item() - break - + break + def merge_to(key): if "alpha" in key: return - + # find original module for this OFT module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: @@ -168,18 +171,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module = name_to_module[module_name] # logger.info(f"apply {key} to {module}") - + oft_blocks = lora_sd[key] - + if isinstance(module, torch.nn.Linear): out_dim = module.out_features elif isinstance(module, torch.nn.Conv2d): out_dim = module.out_channels - + num_blocks = dim block_size = out_dim // dim constraint = (0 if alpha is None else alpha) * out_dim - + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) @@ -188,24 +191,24 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * I R = torch.block_diag(*block_R_weighted) - + # get org weight org_sd = module.state_dict() org_weight = org_sd["weight"].to(device) R = R.to(org_weight.device, dtype=org_weight.dtype) - + if org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: weight = torch.einsum("oi, op -> pi", org_weight, R) - - weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor - + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor + module.weight = torch.nn.Parameter(weight) # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough - max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) @@ -258,7 +261,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): for key in tqdm(lora_sd.keys()): if "alpha" in key: continue - + if "lora_up" in key and concat: concat_dim = 1 elif "lora_down" in key and concat: @@ -272,8 +275,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 - + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -295,7 +298,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): dim = merged_sd[key_down].shape[0] perm = torch.randperm(dim) merged_sd[key_down] = merged_sd[key_down][perm] - merged_sd[key_up] = merged_sd[key_up][:,perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") @@ -323,7 +326,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): if p == "float": @@ -410,10 +415,16 @@ def setup_parser() -> argparse.ArgumentParser: help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument( @@ -431,8 +442,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--shuffle", action="store_true", - help="shuffle lora weight./ " - + "LoRAの重みをシャッフルする", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) return parser