diff --git a/networks/convert_anima_lora_to_comfy.py b/networks/convert_anima_lora_to_comfy.py
new file mode 100644
index 00000000..880a2f78
--- /dev/null
+++ b/networks/convert_anima_lora_to_comfy.py
@@ -0,0 +1,156 @@
+import argparse
+from safetensors.torch import save_file
+from safetensors import safe_open
+
+
+from library import train_util
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def main(args):
+    # load source safetensors
+    logger.info(f"Loading source file {args.src_path}")
+    state_dict = {}
+    with safe_open(args.src_path, framework="pt") as f:
+        metadata = f.metadata()
+        for k in f.keys():
+            state_dict[k] = f.get_tensor(k)
+
+    logger.info("Converting...")
+
+    keys = list(state_dict.keys())
+    count = 0
+
+    for k in keys:
+        if not args.reverse:
+            is_dit_lora = k.startswith("lora_unet_")
+            module_and_weight_name = "_".join(k.split("_")[2:])  # Remove `lora_unet_` or `lora_te_` prefix
+
+            # Split at the first dot, e.g., "block1_linear.weight" -> "block1_linear", "weight"
+            module_name, weight_name = module_and_weight_name.split(".", 1)
+
+            # Weight name conversion: lora_up/lora_down to lora_B/lora_A
+            if weight_name.startswith("lora_up"):
+                weight_name = weight_name.replace("lora_up", "lora_B")
+            elif weight_name.startswith("lora_down"):
+                weight_name = weight_name.replace("lora_down", "lora_A")
+            else:
+                # Keep other weight names as-is: e.g. alpha
+                pass
+
+            # Module name conversion: convert underscores to dots
+            original_module_name = module_name.replace("_", ".")  # Convert to dot notation
+
+            # Convert back illegal dots in module names
+            # DiT
+            original_module_name = original_module_name.replace(".linear.", ".linear_")
+            original_module_name = original_module_name.replace("t.embedding.norm", "t_embedding_norm")
+            original_module_name = original_module_name.replace("x.embedder", "x_embedder")
+            original_module_name = original_module_name.replace("adaln.modulation.cross_attn", "adaln_modulation_cross_attn")
+            original_module_name = original_module_name.replace("adaln.modulation.mlp", "adaln_modulation_mlp")
+            original_module_name = original_module_name.replace("cross.attn", "cross_attn")
+            original_module_name = original_module_name.replace("k.proj", "k_proj")
+            original_module_name = original_module_name.replace("k.norm", "k_norm")
+            original_module_name = original_module_name.replace("q.proj", "q_proj")
+            original_module_name = original_module_name.replace("q.norm", "q_norm")
+            original_module_name = original_module_name.replace("v.proj", "v_proj")
+            original_module_name = original_module_name.replace("o.proj", "o_proj")
+            original_module_name = original_module_name.replace("output.proj", "output_proj")
+            original_module_name = original_module_name.replace("self.attn", "self_attn")
+            original_module_name = original_module_name.replace("final.layer", "final_layer")
+            original_module_name = original_module_name.replace("adaln.modulation", "adaln_modulation")
+            original_module_name = original_module_name.replace("norm.cross.attn", "norm_cross_attn")
+            original_module_name = original_module_name.replace("norm.mlp", "norm_mlp")
+            original_module_name = original_module_name.replace("norm.self.attn", "norm_self_attn")
+            original_module_name = original_module_name.replace("out.proj", "out_proj")
+
+            # Qwen3
+            original_module_name = original_module_name.replace("embed.tokens", "embed_tokens")
+            original_module_name = original_module_name.replace("input.layernorm", "input_layernorm")
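+            # e.g. an illustrative text-encoder key "lora_te_model_layers_0_mlp_down_proj.lora_down.weight"
+            # reaches this point as module "model.layers.0.mlp.down.proj"; the replaces below restore
+            # "down_proj", giving the final key "text_encoder.qwen3.model.layers.0.mlp.down_proj.lora_A.weight"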
+            original_module_name = original_module_name.replace("down.proj", "down_proj")
+            original_module_name = original_module_name.replace("gate.proj", "gate_proj")
+            original_module_name = original_module_name.replace("up.proj", "up_proj")
+            original_module_name = original_module_name.replace("post.attention.layernorm", "post_attention_layernorm")
+
+            # Prefix conversion
+            new_prefix = "diffusion_model." if is_dit_lora else "text_encoder.qwen3."
+
+            new_k = f"{new_prefix}{original_module_name}.{weight_name}"
+        else:
+            if k.startswith("diffusion_model."):
+                is_dit_lora = True
+                module_and_weight_name = k[len("diffusion_model.") :]
+            elif k.startswith("text_encoder.qwen3."):
+                is_dit_lora = False
+                module_and_weight_name = k[len("text_encoder.qwen3.") :]
+            else:
+                logger.warning(f"Skipping unrecognized key {k}")
+                continue
+
+            # Get weight name
+            if ".lora_" in module_and_weight_name:
+                module_name, weight_name = module_and_weight_name.rsplit(".lora_", 1)
+                weight_name = "lora_" + weight_name
+            else:
+                module_name, weight_name = module_and_weight_name.rsplit(".", 1)  # Keep other weight names as-is: e.g. alpha
+
+            # Weight name conversion: lora_B/lora_A to lora_up/lora_down
+            # Note: we only convert lora_A and lora_B weights, other weights are kept as-is
+            if weight_name.startswith("lora_B"):
+                weight_name = weight_name.replace("lora_B", "lora_up")
+            elif weight_name.startswith("lora_A"):
+                weight_name = weight_name.replace("lora_A", "lora_down")
+
+            # Module name conversion: convert dots to underscores
+            module_name = module_name.replace(".", "_")  # Convert to underscore notation
+
+            # Prefix conversion
+            prefix = "lora_unet_" if is_dit_lora else "lora_te_"
+
+            new_k = f"{prefix}{module_name}.{weight_name}"
+
+        state_dict[new_k] = state_dict.pop(k)
+        count += 1
+
+    logger.info(f"Converted {count} keys")
+    if count == 0:
+        logger.warning("No keys were converted. Please check if the source file is in the expected format.")
+    elif count > 0 and count < len(keys):
+        logger.warning(
+            f"Only {count} out of {len(keys)} keys were converted. Please check if there are unexpected keys in the source file."
+        )
+
+    # Calculate hash
+    if metadata is not None:
+        logger.info("Calculating hashes and creating metadata...")
+        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+        metadata["sshs_model_hash"] = model_hash
+        metadata["sshs_legacy_hash"] = legacy_hash
+
+    # save destination safetensors
+    logger.info(f"Saving destination file {args.dst_path}")
+    save_file(state_dict, args.dst_path, metadata=metadata)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert LoRA format")
+    parser.add_argument(
+        "src_path",
+        type=str,
+        default=None,
+        help="source path, sd-scripts format (or ComfyUI compatible format if --reverse is set, only supported for LoRAs converted by this script)",
+    )
+    parser.add_argument(
+        "dst_path",
+        type=str,
+        default=None,
+        help="destination path, ComfyUI compatible format (or sd-scripts format if --reverse is set)",
+    )
+    parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
+    args = parser.parse_args()
+    main(args)
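A usage sketch based on the argparse definition above (file names are illustrative, not from the PR):

    python networks/convert_anima_lora_to_comfy.py anima_lora.safetensors anima_lora_comfy.safetensors
    python networks/convert_anima_lora_to_comfy.py anima_lora_comfy.safetensors anima_lora.safetensors --reverse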