Merge branch 'sd3' into multi-gpu-caching

This commit is contained in:
kohya-ss
2024-10-13 11:52:42 +09:00
4 changed files with 82 additions and 18 deletions

View File

@@ -21,10 +21,20 @@ Oct 12, 2024 (update 1):
- Even if this option is specified, the cache will be created if the file does not exist. - Even if this option is specified, the cache will be created if the file does not exist.
- `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead.
Oct 12, 2024 (update 1):
- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models.
- A compact model is a model that retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38.
- The model is automatically determined based on the keys in *.safetensors.
- Specifications for compact model safetensors:
- Please specify the block indices as consecutive numbers. An error will occur if there are missing numbers. For example, if you reduce the double blocks to 15, the maximum key will be `double_blocks.14.*`. The same applies to single blocks.
- LoRA training is unverified.
- The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified.
Oct 12, 2024: Oct 12, 2024:
- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
- It should work with all training scripts, but it is unverified. - In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified.
- Set up multi-GPU training with `accelerate config`. - Set up multi-GPU training with `accelerate config`.
- Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly.
``` ```

View File

@@ -141,7 +141,7 @@ def train(args):
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset: if args.debug_dataset:
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
@@ -514,8 +514,8 @@ def train(args):
library.adafactor_fused.patch_adafactor_fused(optimizer) library.adafactor_fused.patch_adafactor_fused(optimizer)
blocks_to_swap = args.blocks_to_swap blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks) num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = 38 # len(flux.single_blocks) num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2 num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set() handled_unit_indices = set()
@@ -607,8 +607,8 @@ def train(args):
parameter_optimizer_map = {} parameter_optimizer_map = {}
blocks_to_swap = args.blocks_to_swap blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks) num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = 38 # len(flux.single_blocks) num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2 num_block_units = num_double_blocks + num_single_blocks // 2
n = 1 # only asynchronous purpose, no need to increase this number n = 1 # only asynchronous purpose, no need to increase this number

View File

@@ -139,7 +139,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return flux_lower return flux_lower
def get_tokenize_strategy(self, args): def get_tokenize_strategy(self, args):
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.t5xxl_max_token_length is None: if args.t5xxl_max_token_length is None:
if is_schnell: if is_schnell:

View File

@@ -1,3 +1,4 @@
from dataclasses import replace
import json import json
import os import os
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -43,8 +44,21 @@ def load_safetensors(
return load_file(path) # prevent device invalid Error return load_file(path) # prevent device invalid Error
def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]: def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
# check the state dict: Diffusers or BFL, dev or schnell """
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
Args:
ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。
Returns:
Tuple[bool, bool, Tuple[int, int], List[str]]:
- bool: Diffusersかどうかを示すフラグ。
- bool: Schnellかどうかを示すフラグ。
- Tuple[int, int]: ダブルブロックとシングルブロックの数。
- List[str]: チェックポイントに含まれるキーのリスト。
"""
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
@@ -61,19 +75,57 @@ def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool,
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
return is_diffusers, is_schnell, ckpt_paths
# check number of double and single blocks
if not is_diffusers:
max_double_block_index = max(
[int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
)
max_single_block_index = max(
[int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
)
else:
max_double_block_index = max(
[
int(key.split(".")[1])
for key in keys
if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
]
)
max_single_block_index = max(
[
int(key.split(".")[1])
for key in keys
if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
]
)
num_double_blocks = max_double_block_index + 1
num_single_blocks = max_single_block_index + 1
return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
def load_flow_model( def load_flow_model(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> Tuple[bool, flux_models.Flux]: ) -> Tuple[bool, flux_models.Flux]:
is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path) is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
# build model # build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"): with torch.device("meta"):
model = flux_models.Flux(flux_models.configs[name].params) params = flux_models.configs[name].params
# set the number of blocks
if params.depth != num_double_blocks:
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
params = replace(params, depth=num_double_blocks)
if params.depth_single_blocks != num_single_blocks:
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
params = replace(params, depth_single_blocks=num_single_blocks)
model = flux_models.Flux(params)
if dtype is not None: if dtype is not None:
model = model.to(dtype) model = model.to(dtype)
@@ -86,7 +138,7 @@ def load_flow_model(
# convert Diffusers to BFL # convert Diffusers to BFL
if is_diffusers: if is_diffusers:
logger.info("Converting Diffusers to BFL") logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd) sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL") logger.info("Converted Diffusers to BFL")
info = model.load_state_dict(sd, strict=False, assign=True) info = model.load_state_dict(sd, strict=False, assign=True)
@@ -349,16 +401,16 @@ BFL_TO_DIFFUSERS_MAP = {
} }
def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
# make reverse map from diffusers map # make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
for b in range(NUM_DOUBLE_BLOCKS): for b in range(num_double_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items(): for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."): if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}." block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights): for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for b in range(NUM_SINGLE_BLOCKS): for b in range(num_single_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items(): for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."): if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}." block_prefix = f"single_transformer_blocks.{b}."
@@ -371,8 +423,10 @@ def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
return diffusers_to_bfl_map return diffusers_to_bfl_map
def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: def convert_diffusers_sd_to_bfl(
diffusers_to_bfl_map = make_diffusers_to_bfl_map() diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
) -> dict[str, torch.Tensor]:
diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
# iterate over three safetensors files to reduce memory usage # iterate over three safetensors files to reduce memory usage
flux_sd = {} flux_sd = {}