mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'sd3' into multi-gpu-caching
This commit is contained in:
12
README.md
12
README.md
@@ -21,10 +21,20 @@ Oct 12, 2024 (update 1):
|
|||||||
- Even if this option is specified, the cache will be created if the file does not exist.
|
- Even if this option is specified, the cache will be created if the file does not exist.
|
||||||
- `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead.
|
- `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead.
|
||||||
|
|
||||||
|
Oct 12, 2024 (update 1):
|
||||||
|
|
||||||
|
- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models.
|
||||||
|
- A compact model is a model that retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38.
|
||||||
|
- The number of double/single blocks is detected automatically from the keys in the `*.safetensors` file.
|
||||||
|
- Specifications for compact model safetensors:
|
||||||
|
- Block indices must be consecutive, starting from 0; an error will occur if any index is missing. For example, if you reduce the double blocks to 15, the maximum key will be `double_blocks.14.*`. The same applies to single blocks.
|
||||||
|
- LoRA training is unverified.
|
||||||
|
- The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified.
|
||||||
|
|
||||||
Oct 12, 2024:
|
Oct 12, 2024:
|
||||||
|
|
||||||
- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
|
- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
|
||||||
- It should work with all training scripts, but it is unverified.
|
- In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified.
|
||||||
- Set up multi-GPU training with `accelerate config`.
|
- Set up multi-GPU training with `accelerate config`.
|
||||||
- Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly.
|
- Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly.
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ def train(args):
|
|||||||
|
|
||||||
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
|
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
|
||||||
|
|
||||||
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
|
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||||
@@ -514,8 +514,8 @@ def train(args):
|
|||||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||||
|
|
||||||
blocks_to_swap = args.blocks_to_swap
|
blocks_to_swap = args.blocks_to_swap
|
||||||
num_double_blocks = 19 # len(flux.double_blocks)
|
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||||
num_single_blocks = 38 # len(flux.single_blocks)
|
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||||
handled_unit_indices = set()
|
handled_unit_indices = set()
|
||||||
|
|
||||||
@@ -607,8 +607,8 @@ def train(args):
|
|||||||
parameter_optimizer_map = {}
|
parameter_optimizer_map = {}
|
||||||
|
|
||||||
blocks_to_swap = args.blocks_to_swap
|
blocks_to_swap = args.blocks_to_swap
|
||||||
num_double_blocks = 19 # len(flux.double_blocks)
|
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||||
num_single_blocks = 38 # len(flux.single_blocks)
|
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||||
|
|
||||||
n = 1 # only asynchronous purpose, no need to increase this number
|
n = 1 # only asynchronous purpose, no need to increase this number
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
return flux_lower
|
return flux_lower
|
||||||
|
|
||||||
def get_tokenize_strategy(self, args):
|
def get_tokenize_strategy(self, args):
|
||||||
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
|
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||||
|
|
||||||
if args.t5xxl_max_token_length is None:
|
if args.t5xxl_max_token_length is None:
|
||||||
if is_schnell:
|
if is_schnell:
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from dataclasses import replace
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
@@ -43,8 +44,21 @@ def load_safetensors(
|
|||||||
return load_file(path) # prevent device invalid Error
|
return load_file(path) # prevent device invalid Error
|
||||||
|
|
||||||
|
|
||||||
def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]:
|
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||||
# check the state dict: Diffusers or BFL, dev or schnell
|
"""
|
||||||
|
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||||
|
- bool: Diffusersかどうかを示すフラグ。
|
||||||
|
- bool: Schnellかどうかを示すフラグ。
|
||||||
|
- Tuple[int, int]: ダブルブロックとシングルブロックの数。
|
||||||
|
- List[str]: チェックポイントに含まれるキーのリスト。
|
||||||
|
"""
|
||||||
|
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
|
||||||
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
|
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
|
||||||
|
|
||||||
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
|
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
|
||||||
@@ -61,19 +75,57 @@ def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool,
|
|||||||
|
|
||||||
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
||||||
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
||||||
return is_diffusers, is_schnell, ckpt_paths
|
|
||||||
|
# check number of double and single blocks
|
||||||
|
if not is_diffusers:
|
||||||
|
max_double_block_index = max(
|
||||||
|
[int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
|
||||||
|
)
|
||||||
|
max_single_block_index = max(
|
||||||
|
[int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
max_double_block_index = max(
|
||||||
|
[
|
||||||
|
int(key.split(".")[1])
|
||||||
|
for key in keys
|
||||||
|
if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
max_single_block_index = max(
|
||||||
|
[
|
||||||
|
int(key.split(".")[1])
|
||||||
|
for key in keys
|
||||||
|
if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
num_double_blocks = max_double_block_index + 1
|
||||||
|
num_single_blocks = max_single_block_index + 1
|
||||||
|
|
||||||
|
return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
|
||||||
|
|
||||||
|
|
||||||
def load_flow_model(
|
def load_flow_model(
|
||||||
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||||
) -> Tuple[bool, flux_models.Flux]:
|
) -> Tuple[bool, flux_models.Flux]:
|
||||||
is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path)
|
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||||
|
|
||||||
# build model
|
# build model
|
||||||
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
|
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
model = flux_models.Flux(flux_models.configs[name].params)
|
params = flux_models.configs[name].params
|
||||||
|
|
||||||
|
# set the number of blocks
|
||||||
|
if params.depth != num_double_blocks:
|
||||||
|
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
|
||||||
|
params = replace(params, depth=num_double_blocks)
|
||||||
|
if params.depth_single_blocks != num_single_blocks:
|
||||||
|
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
|
||||||
|
params = replace(params, depth_single_blocks=num_single_blocks)
|
||||||
|
|
||||||
|
model = flux_models.Flux(params)
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
model = model.to(dtype)
|
model = model.to(dtype)
|
||||||
|
|
||||||
@@ -86,7 +138,7 @@ def load_flow_model(
|
|||||||
# convert Diffusers to BFL
|
# convert Diffusers to BFL
|
||||||
if is_diffusers:
|
if is_diffusers:
|
||||||
logger.info("Converting Diffusers to BFL")
|
logger.info("Converting Diffusers to BFL")
|
||||||
sd = convert_diffusers_sd_to_bfl(sd)
|
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
||||||
logger.info("Converted Diffusers to BFL")
|
logger.info("Converted Diffusers to BFL")
|
||||||
|
|
||||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||||
@@ -349,16 +401,16 @@ BFL_TO_DIFFUSERS_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
|
def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
|
||||||
# make reverse map from diffusers map
|
# make reverse map from diffusers map
|
||||||
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
|
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
|
||||||
for b in range(NUM_DOUBLE_BLOCKS):
|
for b in range(num_double_blocks):
|
||||||
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
||||||
if key.startswith("double_blocks."):
|
if key.startswith("double_blocks."):
|
||||||
block_prefix = f"transformer_blocks.{b}."
|
block_prefix = f"transformer_blocks.{b}."
|
||||||
for i, weight in enumerate(weights):
|
for i, weight in enumerate(weights):
|
||||||
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
|
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
|
||||||
for b in range(NUM_SINGLE_BLOCKS):
|
for b in range(num_single_blocks):
|
||||||
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
||||||
if key.startswith("single_blocks."):
|
if key.startswith("single_blocks."):
|
||||||
block_prefix = f"single_transformer_blocks.{b}."
|
block_prefix = f"single_transformer_blocks.{b}."
|
||||||
@@ -371,8 +423,10 @@ def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
|
|||||||
return diffusers_to_bfl_map
|
return diffusers_to_bfl_map
|
||||||
|
|
||||||
|
|
||||||
def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
def convert_diffusers_sd_to_bfl(
|
||||||
diffusers_to_bfl_map = make_diffusers_to_bfl_map()
|
diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
|
||||||
|
|
||||||
# iterate over three safetensors files to reduce memory usage
|
# iterate over three safetensors files to reduce memory usage
|
||||||
flux_sd = {}
|
flux_sd = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user