diff --git a/README.md b/README.md index 2b256283..544c665d 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,20 @@ Oct 12, 2024 (update 1): - Even if this option is specified, the cache will be created if the file does not exist. - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. +Oct 12, 2024 (update 1): + +- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models. + - A compact model is a model that retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38. + - The model is automatically determined based on the keys in *.safetensors. + - Specifications for compact model safetensors: + - Please specify the block indices as consecutive numbers. An error will occur if there are missing numbers. For example, if you reduce the double blocks to 15, the maximum key will be `double_blocks.14.*`. The same applies to single blocks. + - LoRA training is unverified. + - The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified. + Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! - - It should work with all training scripts, but it is unverified. + - In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified. - Set up multi-GPU training with `accelerate config`. - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. ``` diff --git a/flux_train.py b/flux_train.py index e18a9244..46a8babd 100644 --- a/flux_train.py +++ b/flux_train.py @@ -141,7 +141,7 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 - _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) if args.debug_dataset: if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( @@ -514,8 +514,8 @@ def train(args): library.adafactor_fused.patch_adafactor_fused(optimizer) blocks_to_swap = args.blocks_to_swap - num_double_blocks = 19 # len(flux.double_blocks) - num_single_blocks = 38 # len(flux.single_blocks) + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 handled_unit_indices = set() @@ -607,8 +607,8 @@ def train(args): parameter_optimizer_map = {} blocks_to_swap = args.blocks_to_swap - num_double_blocks = 19 # len(flux.double_blocks) - num_single_blocks = 38 # len(flux.single_blocks) + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 n = 1 # only asynchronous purpose, no need to increase this number diff --git a/flux_train_network.py b/flux_train_network.py index 3bd8316d..aa92fe3a 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -139,7 +139,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return flux_lower def get_tokenize_strategy(self, args): - _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) if args.t5xxl_max_token_length is None: if is_schnell: diff --git a/library/flux_utils.py b/library/flux_utils.py index 713814e2..7a1ec37b 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,3 +1,4 @@ +from dataclasses import replace import json import os from typing import List, Optional, Tuple, Union @@ -43,8 +44,21 @@ def load_safetensors( return load_file(path) # prevent device invalid Error -def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]: - # check the state dict: Diffusers or BFL, dev or schnell +def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: + """ + チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 + + Args: + ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。 + + Returns: + Tuple[bool, bool, Tuple[int, int], List[str]]: + - bool: Diffusersかどうかを示すフラグ。 + - bool: Schnellかどうかを示すフラグ。 + - Tuple[int, int]: ダブルブロックとシングルブロックの数。 + - List[str]: チェックポイントに含まれるキーのリスト。 + """ + # check the state dict: Diffusers or BFL, dev or schnell, number of blocks logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers @@ -61,19 +75,57 @@ def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) - return is_diffusers, is_schnell, ckpt_paths + + # check number of double and single blocks + if not is_diffusers: + max_double_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")] + ) + max_single_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")] + ) + else: + max_double_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias") + ] + ) + max_single_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias") + ] + ) + + num_double_blocks = max_double_block_index + 1 + num_single_blocks = max_single_block_index + 1 + + return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths def load_flow_model( ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ) -> Tuple[bool, flux_models.Flux]: - is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path) + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL # build model logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") with torch.device("meta"): - model = flux_models.Flux(flux_models.configs[name].params) + params = flux_models.configs[name].params + + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) + + model = flux_models.Flux(params) if dtype is not None: model = model.to(dtype) @@ -86,7 +138,7 @@ def load_flow_model( # convert Diffusers to BFL if is_diffusers: logger.info("Converting Diffusers to BFL") - sd = convert_diffusers_sd_to_bfl(sd) + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) logger.info("Converted Diffusers to BFL") info = model.load_state_dict(sd, strict=False, assign=True) @@ -349,16 +401,16 @@ BFL_TO_DIFFUSERS_MAP = { } -def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: +def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]: # make reverse map from diffusers map diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) - for b in range(NUM_DOUBLE_BLOCKS): + for b in range(num_double_blocks): for key, weights in BFL_TO_DIFFUSERS_MAP.items(): if key.startswith("double_blocks."): block_prefix = f"transformer_blocks.{b}." for i, weight in enumerate(weights): diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for b in range(NUM_SINGLE_BLOCKS): + for b in range(num_single_blocks): for key, weights in BFL_TO_DIFFUSERS_MAP.items(): if key.startswith("single_blocks."): block_prefix = f"single_transformer_blocks.{b}." @@ -371,8 +423,10 @@ def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: return diffusers_to_bfl_map -def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - diffusers_to_bfl_map = make_diffusers_to_bfl_map() +def convert_diffusers_sd_to_bfl( + diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS +) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks) # iterate over three safetensors files to reduce memory usage flux_sd = {}