feat: add conversion script for LoRA models to ComfyUI format with reverse option

This commit is contained in:
Kohya S
2025-09-16 21:48:47 +09:00
parent f318ddaeea
commit cbe2a9da45
2 changed files with 55 additions and 19 deletions

View File

@@ -24,28 +24,47 @@ def main(args):
logger.info(f"Converting...")
# Key mapping tables: (sd-scripts format, ComfyUI format)
double_blocks_mappings = [
("img_mlp_fc1", "img_mlp_0"),
("img_mlp_fc2", "img_mlp_2"),
("img_mod_linear", "img_mod_lin"),
("txt_mlp_fc1", "txt_mlp_0"),
("txt_mlp_fc2", "txt_mlp_2"),
("txt_mod_linear", "txt_mod_lin"),
]
single_blocks_mappings = [
("modulation_linear", "modulation_lin"),
]
keys = list(state_dict.keys())
count = 0
for k in keys:
new_k = k
if "double_blocks" in k:
new_k = (
k.replace("img_mlp_fc1", "img_mlp_0").replace("img_mlp_fc2", "img_mlp_2").replace("img_mod_linear", "img_mod_lin")
)
new_k = (
new_k.replace("txt_mlp_fc1", "txt_mlp_0")
.replace("txt_mlp_fc2", "txt_mlp_2")
.replace("txt_mod_linear", "txt_mod_lin")
)
if new_k != k:
state_dict[new_k] = state_dict.pop(k)
count += 1
# print(f"Renamed {k} to {new_k}")
mappings = double_blocks_mappings
elif "single_blocks" in k:
new_k = k.replace("modulation_linear", "modulation_lin")
if new_k != k:
state_dict[new_k] = state_dict.pop(k)
count += 1
# print(f"Renamed {k} to {new_k}")
mappings = single_blocks_mappings
else:
continue
# Apply mappings based on conversion direction
for src_key, dst_key in mappings:
if args.reverse:
# ComfyUI to sd-scripts: swap src and dst
new_k = new_k.replace(dst_key, src_key)
else:
# sd-scripts to ComfyUI: use as-is
new_k = new_k.replace(src_key, dst_key)
if new_k != k:
state_dict[new_k] = state_dict.pop(k)
count += 1
# print(f"Renamed {k} to {new_k}")
logger.info(f"Converted {count} keys")
# Calculate hash
@@ -64,5 +83,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA format")
parser.add_argument("src_path", type=str, default=None, help="source path, sd-scripts format")
parser.add_argument("dst_path", type=str, default=None, help="destination path, ComfyUI format")
parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
args = parser.parse_args()
main(args)