From ecaea909b10fa8b3eb94a1cf57b26d5daba1683e Mon Sep 17 00:00:00 2001
From: kohya-ss
Date: Sat, 12 Oct 2024 20:26:57 +0900
Subject: [PATCH 1/2] update README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 37fc911f..9128bf8d 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ The command to install PyTorch is as follows:
 Oct 12, 2024:
 
 - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
-  - It should work with all training scripts, but it is unverified.
+  - In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified.
   - Set up multi-GPU training with `accelerate config`.
   - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly.
 ```

From e277b5789e791539b5e51187530f11bd94e24871 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 12 Oct 2024 21:49:07 +0900
Subject: [PATCH 2/2] Update FLUX.1 support for compact models

---
 README.md             | 10 ++++++
 flux_train.py         | 12 +++----
 flux_train_network.py |  2 +-
 library/flux_utils.py | 76 ++++++++++++++++++++++++++++++++++++-------
 4 files changed, 82 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 9128bf8d..b64515a1 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,16 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Oct 12, 2024 (update 1):
+
+- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models.
+  - A compact model keeps the FLUX.1 architecture but reduces the number of double/single blocks below the default 19/38.
+  - The model configuration is determined automatically from the keys in the `*.safetensors` file.
+  - Specification for compact model safetensors:
+    - Block indices must be consecutive; an error occurs if any index is missing. For example, if the double blocks are reduced to 15, the largest key is `double_blocks.14.*`. The same applies to single blocks.
+  - LoRA training is unverified.
+  - Trained models can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified.
+
 Oct 12, 2024:
 
 - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
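[Editor's note: the consecutive-index rule above can be checked up front. A minimal sketch, not part of the patch; `count_blocks` is a hypothetical helper and the checkpoint path is illustrative, assuming only the `safetensors` package:]

```python
# Sketch: verify that block indices in a compact-model checkpoint are consecutive.
from safetensors import safe_open

def count_blocks(ckpt_path: str, prefix: str) -> int:
    """Return the number of `prefix` blocks, raising if indices are not 0..N-1."""
    with safe_open(ckpt_path, framework="pt") as f:
        indices = sorted({int(k.split(".")[1]) for k in f.keys() if k.startswith(prefix + ".")})
    if indices != list(range(len(indices))):
        raise ValueError(f"{prefix} indices are not consecutive: {indices}")
    return len(indices)

# e.g. a compact model reduced to 15 double blocks returns 15 here,
# with `double_blocks.14.*` as the largest keys:
# count_blocks("flux1-compact.safetensors", "double_blocks")
```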
diff --git a/flux_train.py b/flux_train.py
index ecc87c0a..2fc13068 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -137,7 +137,7 @@ def train(args):
 
     train_dataset_group.verify_bucket_reso_steps(16)  # TODO verify that this is correct
 
-    _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
+    _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
     if args.debug_dataset:
         if args.cache_text_encoder_outputs:
             strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
@@ -181,7 +181,7 @@ def train(args):
     # load VAE for caching latents
     ae = None
     if cache_latents:
-        ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
         ae.to(accelerator.device, dtype=weight_dtype)
         ae.requires_grad_(False)
         ae.eval()
@@ -510,8 +510,8 @@ def train(args):
             library.adafactor_fused.patch_adafactor_fused(optimizer)
 
             blocks_to_swap = args.blocks_to_swap
-            num_double_blocks = 19  # len(flux.double_blocks)
-            num_single_blocks = 38  # len(flux.single_blocks)
+            num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
+            num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
             num_block_units = num_double_blocks + num_single_blocks // 2
 
             handled_unit_indices = set()
@@ -603,8 +603,8 @@ def train(args):
             parameter_optimizer_map = {}
 
             blocks_to_swap = args.blocks_to_swap
-            num_double_blocks = 19  # len(flux.double_blocks)
-            num_single_blocks = 38  # len(flux.single_blocks)
+            num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
+            num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
             num_block_units = num_double_blocks + num_single_blocks // 2
 
             n = 1  # only asynchronous purpose, no need to increase this number
diff --git a/flux_train_network.py b/flux_train_network.py
index 5d14bd28..a24c1905 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -139,7 +139,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         return flux_lower
 
     def get_tokenize_strategy(self, args):
-        _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
+        _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
 
         if args.t5xxl_max_token_length is None:
             if is_schnell:
diff --git a/library/flux_utils.py b/library/flux_utils.py
index 713814e2..7a1ec37b 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -1,3 +1,4 @@
+from dataclasses import replace
 import json
 import os
 from typing import List, Optional, Tuple, Union
@@ -43,8 +44,21 @@ def load_safetensors(
     return load_file(path)  # prevent device invalid Error
 
 
-def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]:
-    # check the state dict: Diffusers or BFL, dev or schnell
+def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
+    """
+    Analyze the checkpoint state and return whether it is Diffusers or BFL, dev or schnell, and the block counts.
+
+    Args:
+        ckpt_path (str): Path to the checkpoint file or directory.
+
+    Returns:
+        Tuple[bool, bool, Tuple[int, int], List[str]]:
+            - bool: Flag indicating whether the checkpoint is in Diffusers format.
+            - bool: Flag indicating whether the model is schnell.
+            - Tuple[int, int]: Number of double blocks and single blocks.
+            - List[str]: List of checkpoint file paths.
+    """
+    # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
     logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
 
     if os.path.isdir(ckpt_path):  # if ckpt_path is a directory, it is Diffusers
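[Editor's note: `analyze_checkpoint_state` (renamed from `check_flux_state_dict_diffusers_schnell`) now returns a 4-tuple, which is why each caller above gains an extra `_`. A sketch of the new unpacking, not part of the patch; the checkpoint path is illustrative:]

```python
from library import flux_utils

ckpt_path = "flux1-dev.safetensors"  # illustrative path

# full unpacking: Diffusers flag, schnell flag, block counts, checkpoint paths
is_diffusers, is_schnell, (num_double, num_single), ckpt_paths = flux_utils.analyze_checkpoint_state(ckpt_path)

# callers that only need the dev/schnell distinction discard the rest:
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(ckpt_path)
```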
@@ -61,19 +75,57 @@ def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool,
 
     is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
     is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
-    return is_diffusers, is_schnell, ckpt_paths
+
+    # check number of double and single blocks
+    if not is_diffusers:
+        max_double_block_index = max(
+            [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
+        )
+        max_single_block_index = max(
+            [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
+        )
+    else:
+        max_double_block_index = max(
+            [
+                int(key.split(".")[1])
+                for key in keys
+                if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
+            ]
+        )
+        max_single_block_index = max(
+            [
+                int(key.split(".")[1])
+                for key in keys
+                if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
+            ]
+        )
+
+    num_double_blocks = max_double_block_index + 1
+    num_single_blocks = max_single_block_index + 1
+
+    return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
 
 
 def load_flow_model(
     ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
 ) -> Tuple[bool, flux_models.Flux]:
-    is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path)
+    is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
     name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
 
     # build model
     logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
     with torch.device("meta"):
-        model = flux_models.Flux(flux_models.configs[name].params)
+        params = flux_models.configs[name].params
+
+        # set the number of blocks
+        if params.depth != num_double_blocks:
+            logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
+            params = replace(params, depth=num_double_blocks)
+        if params.depth_single_blocks != num_single_blocks:
+            logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
+            params = replace(params, depth_single_blocks=num_single_blocks)
+
+        model = flux_models.Flux(params)
 
     if dtype is not None:
         model = model.to(dtype)
@@ -86,7 +138,7 @@ def load_flow_model(
     # convert Diffusers to BFL
     if is_diffusers:
         logger.info("Converting Diffusers to BFL")
-        sd = convert_diffusers_sd_to_bfl(sd)
+        sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
         logger.info("Converted Diffusers to BFL")
 
     info = model.load_state_dict(sd, strict=False, assign=True)
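[Editor's note: the newly imported `replace` is `dataclasses.replace`, which copies a dataclass instance with selected fields overridden; this is how `load_flow_model` shrinks the block counts without mutating the shared config. A standalone sketch of the pattern; the `Params` class here is an illustrative stand-in, not the actual `flux_models` definition:]

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Params:  # illustrative stand-in for the flux_models config params
    depth: int = 19                # number of double blocks
    depth_single_blocks: int = 38  # number of single blocks

params = Params()
params = replace(params, depth=15)  # returns a new instance; the original is untouched
print(params.depth, params.depth_single_blocks)  # -> 15 38
```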
@@ -349,16 +401,16 @@ BFL_TO_DIFFUSERS_MAP = {
 }
 
 
-def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
+def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
     # make reverse map from diffusers map
     diffusers_to_bfl_map = {}  # key: diffusers_key, value: (index, bfl_key)
-    for b in range(NUM_DOUBLE_BLOCKS):
+    for b in range(num_double_blocks):
         for key, weights in BFL_TO_DIFFUSERS_MAP.items():
             if key.startswith("double_blocks."):
                 block_prefix = f"transformer_blocks.{b}."
                 for i, weight in enumerate(weights):
                     diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
-    for b in range(NUM_SINGLE_BLOCKS):
+    for b in range(num_single_blocks):
         for key, weights in BFL_TO_DIFFUSERS_MAP.items():
             if key.startswith("single_blocks."):
                 block_prefix = f"single_transformer_blocks.{b}."
@@ -371,8 +423,10 @@ def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
     return diffusers_to_bfl_map
 
 
-def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-    diffusers_to_bfl_map = make_diffusers_to_bfl_map()
+def convert_diffusers_sd_to_bfl(
+    diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
+) -> dict[str, torch.Tensor]:
+    diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
 
     # iterate over three safetensors files to reduce memory usage
     flux_sd = {}
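[Editor's note: because `convert_diffusers_sd_to_bfl` defaults its new parameters to `NUM_DOUBLE_BLOCKS`/`NUM_SINGLE_BLOCKS`, existing call sites keep working unchanged, while compact models pass the counts discovered by `analyze_checkpoint_state`. A sketch, assuming `diffusers_sd` is an already-loaded Diffusers-format state dict and 15/30 are hypothetical compact block counts:]

```python
from library.flux_utils import convert_diffusers_sd_to_bfl

# diffusers_sd: dict[str, torch.Tensor], loaded elsewhere

# full-size model: the 19/38 defaults apply
bfl_sd = convert_diffusers_sd_to_bfl(diffusers_sd)

# compact model: pass the counts returned by analyze_checkpoint_state
bfl_sd = convert_diffusers_sd_to_bfl(diffusers_sd, num_double_blocks=15, num_single_blocks=30)
```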