mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Update FLUX.1 support for compact models
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from dataclasses import replace
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@@ -43,8 +44,21 @@ def load_safetensors(
|
||||
return load_file(path) # prevent device invalid Error
|
||||
|
||||
|
||||
def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]:
|
||||
# check the state dict: Diffusers or BFL, dev or schnell
|
||||
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||
"""
|
||||
Analyze the checkpoint state and return whether it is Diffusers or BFL, whether it is dev or schnell, and the computed block counts.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the checkpoint file or directory.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||
- bool: Flag indicating whether the checkpoint is in Diffusers format.
|
||||
- bool: Flag indicating whether the model is schnell.
|
||||
- Tuple[int, int]: The number of double blocks and the number of single blocks.
|
||||
- List[str]: The checkpoint path list (`ckpt_paths`); note that this is not the list of state-dict keys.
|
||||
"""
|
||||
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
|
||||
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
|
||||
|
||||
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
|
||||
@@ -61,19 +75,57 @@ def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool,
|
||||
|
||||
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
||||
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
||||
return is_diffusers, is_schnell, ckpt_paths
|
||||
|
||||
# check number of double and single blocks
|
||||
if not is_diffusers:
|
||||
max_double_block_index = max(
|
||||
[int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
|
||||
)
|
||||
max_single_block_index = max(
|
||||
[int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
|
||||
)
|
||||
else:
|
||||
max_double_block_index = max(
|
||||
[
|
||||
int(key.split(".")[1])
|
||||
for key in keys
|
||||
if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
|
||||
]
|
||||
)
|
||||
max_single_block_index = max(
|
||||
[
|
||||
int(key.split(".")[1])
|
||||
for key in keys
|
||||
if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
|
||||
]
|
||||
)
|
||||
|
||||
num_double_blocks = max_double_block_index + 1
|
||||
num_single_blocks = max_single_block_index + 1
|
||||
|
||||
return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
|
||||
|
||||
|
||||
def load_flow_model(
|
||||
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||
) -> Tuple[bool, flux_models.Flux]:
|
||||
is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path)
|
||||
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
|
||||
# build model
|
||||
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
|
||||
with torch.device("meta"):
|
||||
model = flux_models.Flux(flux_models.configs[name].params)
|
||||
params = flux_models.configs[name].params
|
||||
|
||||
# set the number of blocks
|
||||
if params.depth != num_double_blocks:
|
||||
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
|
||||
params = replace(params, depth=num_double_blocks)
|
||||
if params.depth_single_blocks != num_single_blocks:
|
||||
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
|
||||
params = replace(params, depth_single_blocks=num_single_blocks)
|
||||
|
||||
model = flux_models.Flux(params)
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
@@ -86,7 +138,7 @@ def load_flow_model(
|
||||
# convert Diffusers to BFL
|
||||
if is_diffusers:
|
||||
logger.info("Converting Diffusers to BFL")
|
||||
sd = convert_diffusers_sd_to_bfl(sd)
|
||||
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
||||
logger.info("Converted Diffusers to BFL")
|
||||
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
@@ -349,16 +401,16 @@ BFL_TO_DIFFUSERS_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
|
||||
def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
|
||||
# make reverse map from diffusers map
|
||||
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
|
||||
for b in range(NUM_DOUBLE_BLOCKS):
|
||||
for b in range(num_double_blocks):
|
||||
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
||||
if key.startswith("double_blocks."):
|
||||
block_prefix = f"transformer_blocks.{b}."
|
||||
for i, weight in enumerate(weights):
|
||||
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
|
||||
for b in range(NUM_SINGLE_BLOCKS):
|
||||
for b in range(num_single_blocks):
|
||||
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
||||
if key.startswith("single_blocks."):
|
||||
block_prefix = f"single_transformer_blocks.{b}."
|
||||
@@ -371,8 +423,10 @@ def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
|
||||
return diffusers_to_bfl_map
|
||||
|
||||
|
||||
def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
diffusers_to_bfl_map = make_diffusers_to_bfl_map()
|
||||
def convert_diffusers_sd_to_bfl(
|
||||
diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
|
||||
) -> dict[str, torch.Tensor]:
|
||||
diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
|
||||
|
||||
# iterate over three safetensors files to reduce memory usage
|
||||
flux_sd = {}
|
||||
|
||||
Reference in New Issue
Block a user