load Diffusers format, check schnell/dev

This commit is contained in:
Kohya S
2024-10-06 21:32:21 +09:00
parent ba08a89894
commit 83e3048cb0
6 changed files with 196 additions and 111 deletions

View File

@@ -29,6 +29,7 @@ from safetensors.torch import safe_open
import torch
from tqdm import tqdm
from library import flux_utils
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging()
@@ -36,65 +37,6 @@ import logging
logger = logging.getLogger(__name__)
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
BFL_TO_DIFFUSERS_MAP = {
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
def convert(args):
# if diffusers_path is folder, get safetensors file
@@ -114,23 +56,7 @@ def convert(args):
save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
# make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
for b in range(NUM_DOUBLE_BLOCKS):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for b in range(NUM_SINGLE_BLOCKS):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
for i, weight in enumerate(weights):
diffusers_to_bfl_map[weight] = (i, key)
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
# iterate over three safetensors files to reduce memory usage
flux_sd = {}