Merge pull request #1690 from kohya-ss/multi-gpu-caching

Caching latents and Text Encoder outputs with multiple GPUs
Kohya S.
2024-10-13 19:25:59 +09:00
committed by GitHub
17 changed files with 346 additions and 258 deletions

View File

@@ -11,6 +11,21 @@ The command to install PyTorch is as follows:
### Recent Updates
Oct 13, 2024:
- Fixed an issue where loading image sizes during dataset initialization took a long time, especially when the dataset contained a large number of images.
- During multi-GPU training, caching of latents and Text Encoder outputs is now distributed across all GPUs (a sketch of the per-GPU sharding follows this entry).
  - Please make sure that `--highvram` and `--vae_batch_size` are specified correctly. If you have enough VRAM, you can increase the batch size to speed up the caching.
  - The `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is the same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching.
  - Multi-threading is also implemented for caching of latents. This may speed up the caching process by about 5% (depending on the environment).
- `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` have also been updated to support multi-GPU caching.
- The `--skip_cache_check` option is added to each training script.
  - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped.
  - Specify this option if you have a large number of cache files and the consistency check takes time.
  - Even if this option is specified, the cache will be created if the file does not exist.
  - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead.
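
To make the sharding above concrete: each process caches only the items whose index equals its `process_index` modulo `num_processes`, and all ranks synchronize afterwards. A minimal, self-contained sketch assuming Accelerate; the item list and the caching call are placeholders, not the library's code:

```python
from accelerate import Accelerator

accelerator = Accelerator()

items = [f"image_{i:04d}.png" for i in range(100)]  # placeholder for the dataset's image list
for i, item in enumerate(items):
    if i % accelerator.num_processes != accelerator.process_index:
        continue  # another process caches this item
    print(f"rank {accelerator.process_index} caches {item}")  # stand-in for the real caching call

accelerator.wait_for_everyone()  # every rank finishes caching before training starts
```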
Oct 12, 2024 (update 1):
- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models.

View File

@@ -59,7 +59,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
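
The comment above ("this must be set before preparing the dataset") reflects a small class-level registry: the dataset looks the strategy up during its own initialization. A generic sketch of that contract with stand-in classes, not the repo's:

```python
from typing import Optional


class LatentsCachingStrategy:
    _strategy: Optional["LatentsCachingStrategy"] = None

    def __init__(self, cache_to_disk: bool, batch_size: Optional[int], skip_cache_check: bool):
        self.cache_to_disk = cache_to_disk
        self.batch_size = batch_size
        self.skip_cache_check = skip_cache_check

    @classmethod
    def set_strategy(cls, strategy: "LatentsCachingStrategy") -> None:
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
        return cls._strategy


class Dataset:
    def __init__(self) -> None:
        # the dataset reads the strategy while initializing, hence the ordering requirement
        self.strategy = LatentsCachingStrategy.get_strategy()
        assert self.strategy is not None, "set the caching strategy before building the dataset"


LatentsCachingStrategy.set_strategy(LatentsCachingStrategy(True, None, False))
dataset = Dataset()  # would fail if set_strategy had not been called first
```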

View File

@@ -57,6 +57,10 @@ def train(args):
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -81,7 +85,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization.
if args.cache_latents:
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -142,7 +146,7 @@ def train(args):
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
)
t5xxl_max_token_length = (
@@ -229,7 +233,7 @@ def train(args):
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None:
@@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--blocks_to_swap",

View File

@@ -188,8 +188,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# if the text encoders are trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
False,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
@@ -222,7 +222,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[1].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:

View File

@@ -325,7 +325,7 @@ class TextEncoderOutputsCachingStrategy:
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
is_weighted: bool = False,
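
Widening `batch_size` to `Optional[int]` lets trainers pass `None`, which (per the changelog note on `--text_encoder_batch_size`) means "use the dataset batch size". A sketch of that fallback; `effective_batch_size` is an illustrative helper, not a function from the repo:

```python
from typing import Optional


def effective_batch_size(strategy_batch_size: Optional[int], dataset_batch_size: int) -> int:
    # None means "fall back to the dataset batch size"
    return dataset_batch_size if strategy_batch_size is None else strategy_batch_size


assert effective_batch_size(None, 4) == 4   # --text_encoder_batch_size not given
assert effective_batch_size(16, 4) == 16    # explicit caching batch size wins
```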

View File

@@ -3,6 +3,7 @@
import argparse
import ast
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
import datetime
import importlib
import json
@@ -31,6 +32,7 @@ import hashlib
import subprocess
from io import BytesIO
import toml
# from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
@@ -1029,7 +1031,7 @@ class BaseDataset(torch.utils.data.Dataset):
]
)
def new_cache_latents(self, model: Any, is_main_process: bool):
def new_cache_latents(self, model: Any, accelerator: Accelerator):
r"""
a new method to cache latents using the configured caching strategy.
the normal cache_latents method is used by default; this method is used when a caching strategy is set.
@@ -1057,60 +1059,77 @@ class BaseDataset(torch.utils.data.Dataset):
and self.random_crop == other.random_crop
)
batches: List[Tuple[Condition, List[ImageInfo]]] = []
batch: List[ImageInfo] = []
current_condition = None
logger.info("checking cache validity...")
for info in tqdm(image_infos):
subset = self.image_to_subset[info.image_key]
# support multiple-gpus
num_processes = accelerator.num_processes
process_index = accelerator.process_index
if info.latents_npz is not None: # fine tuning dataset
continue
# define a function to submit a batch to cache
def submit_batch(batch, cond):
for info in batch:
if info.image is not None and isinstance(info.image, Future):
info.image = info.image.result() # future to image
caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop)
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
if not is_main_process: # prepare for multi-gpu, only store to info
# define ThreadPoolExecutor to load images in parallel
max_workers = min(os.cpu_count(), len(image_infos))
max_workers = max(1, max_workers // num_processes) # consider multi-gpu
max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size
executor = ThreadPoolExecutor(max_workers)
try:
# iterate images
logger.info("caching latents...")
for i, info in enumerate(tqdm(image_infos)):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: # fine tuning dataset
continue
cache_available = caching_strategy.is_disk_cached_latents_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
if cache_available: # do not add to batch
continue
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
batches.append((current_condition, batch))
batch = []
# if i modulo num_processes does not equal process_index, skip caching
# this makes each process cache different latents
if i % num_processes != process_index:
continue
batch.append(info)
current_condition = condition
# print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
# if number of data in batch is enough, flush the batch
if len(batch) >= caching_strategy.batch_size:
batches.append((current_condition, batch))
batch = []
current_condition = None
cache_available = caching_strategy.is_disk_cached_latents_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
if cache_available: # do not add to batch
continue
if len(batch) > 0:
batches.append((current_condition, batch))
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
submit_batch(batch, current_condition)
batch = []
# if cache to disk, don't cache latents in non-main process, set to info only
if caching_strategy.cache_to_disk and not is_main_process:
return
if info.image is None:
# load image in parallel
info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
if len(batches) == 0:
logger.info("no latents to cache")
return
batch.append(info)
current_condition = condition
# iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
logger.info("caching latents...")
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
# if number of data in batch is enough, flush the batch
if len(batch) >= caching_strategy.batch_size:
submit_batch(batch, current_condition)
batch = []
current_condition = None
if len(batch) > 0:
submit_batch(batch, current_condition)
finally:
executor.shutdown()
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
# multi-GPU is not supported here; use tools/cache_latents.py for that
@@ -1187,7 +1206,7 @@ class BaseDataset(torch.utils.data.Dataset):
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
r"""
a new method to cache text encoder outputs using the configured caching strategy.
"""
@@ -1202,15 +1221,25 @@ class BaseDataset(torch.utils.data.Dataset):
# split by resolution
batches = []
batch = []
logger.info("checking cache validity...")
for info in tqdm(image_infos):
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
# check disk cache exists and size of latents
# support multiple-gpus
num_processes = accelerator.num_processes
process_index = accelerator.process_index
logger.info("checking cache validity...")
for i, info in enumerate(tqdm(image_infos)):
# check disk cache exists and size of text encoder outputs
if caching_strategy.cache_to_disk:
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
# if i modulo num_processes does not equal process_index, skip caching
# this makes each process cache different text encoder outputs
if i % num_processes != process_index:
continue
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available or not is_main_process: # do not add to batch
if cache_available: # do not add to batch
continue
batch.append(info)
@@ -2327,8 +2356,8 @@ class ControlNetDataset(BaseDataset):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def new_cache_latents(self, model: Any, is_main_process: bool):
return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process)
def new_cache_latents(self, model: Any, accelerator: Accelerator):
return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
@@ -2432,10 +2461,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
def new_cache_latents(self, model: Any, is_main_process: bool):
def new_cache_latents(self, model: Any, accelerator: Accelerator):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_latents(model, is_main_process)
dataset.new_cache_latents(model, accelerator)
accelerator.wait_for_everyone()
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2453,10 +2483,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_text_encoder_outputs(models, is_main_process)
dataset.new_cache_text_encoder_outputs(models, accelerator)
accelerator.wait_for_everyone()
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
@@ -4054,15 +4085,18 @@ def verify_command_line_training_args(args: argparse.Namespace):
)
def enable_high_vram(args: argparse.Namespace):
if args.highvram:
logger.info("highvram is enabled / highvramが有効です")
global HIGH_VRAM
HIGH_VRAM = True
def verify_training_args(args: argparse.Namespace):
r"""
Verify training arguments. Also reflect the highvram option in the global variable
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
"""
if args.highvram:
print("highvram is enabled / highvramが有効です")
global HIGH_VRAM
HIGH_VRAM = True
enable_high_vram(args)
if args.v_parameterization and not args.v2:
logger.warning(
@@ -4226,6 +4260,12 @@ def add_dataset_arguments(
action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可",
)
parser.add_argument(
"--skip_cache_check",
action="store_true",
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
" / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
)
parser.add_argument(
"--enable_bucket",
action="store_true",
@@ -5100,15 +5140,24 @@ def prepare_accelerator(args: argparse.Namespace):
dynamo_backend = args.dynamo_backend
kwargs_handlers = [
InitProcessGroupKwargs(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
) if torch.cuda.device_count() > 1 else None,
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
static_graph=args.ddp_static_graph
) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
(
InitProcessGroupKwargs(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method=(
"env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
),
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
)
if torch.cuda.device_count() > 1
else None
),
(
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
]
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
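
The reworked `new_cache_latents` above overlaps image loading with encoding: loads are submitted to a thread pool, the pending `Future` is stored on the item, and futures are resolved only when a batch is flushed. A self-contained sketch of that pattern; the loader and the flush body are placeholders:

```python
import os
from concurrent.futures import Future, ThreadPoolExecutor


def load_image(path: str) -> str:
    return f"pixels of {path}"  # placeholder for the real image loader


def submit_batch(batch: list) -> None:
    for item in batch:
        if isinstance(item["image"], Future):
            item["image"] = item["image"].result()  # future -> image; blocks if still loading
    # ... encode the batch with the VAE and write .npz files here ...


paths = [f"img_{i}.png" for i in range(10)]  # placeholder paths
batch_size = 4
max_workers = max(1, min(os.cpu_count() or 1, batch_size))  # bounded, as in the diff

executor = ThreadPoolExecutor(max_workers)
try:
    batch = []
    for path in paths:
        batch.append({"image": executor.submit(load_image, path)})  # load in parallel
        if len(batch) >= batch_size:
            submit_batch(batch)
            batch = []
    if batch:
        submit_batch(batch)  # flush the remainder
finally:
    executor.shutdown()
```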

View File

@@ -57,6 +57,10 @@ def train(args):
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
assert (
not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -103,7 +107,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization.
if args.cache_latents:
latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -312,7 +316,7 @@ def train(args):
text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
False,
args.skip_cache_check,
train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
@@ -325,7 +329,7 @@ def train(args):
t5xxl.to(t5xxl_device, dtype=t5xxl_dtype)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None:
@@ -1052,7 +1056,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--skip_cache_check",
action="store_true",
help="skip cache (latents and text encoder outputs) check / キャッシュlatentsとtext encoder outputsのチェックをスキップする",
)
parser.add_argument(
"--num_last_block_to_freeze",

View File

@@ -131,7 +131,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization.
if args.cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -328,7 +328,7 @@ def train(args):
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()

View File

@@ -84,7 +84,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -230,7 +230,7 @@ def train(args):
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()

View File

@@ -93,7 +93,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -202,7 +202,7 @@ def train(args):
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()
@@ -431,7 +431,6 @@ def train(args):
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached

View File

@@ -67,7 +67,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
@@ -80,7 +80,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
)
else:
return None
@@ -102,9 +102,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(
text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
)
dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator)
accelerator.wait_for_everyone()
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU

View File

@@ -49,7 +49,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy

View File

@@ -9,7 +9,7 @@ from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
from library import train_util
from library import sdxl_train_util
from library.config_util import (
@@ -17,42 +17,74 @@ from library.config_util import (
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None:
if is_flux:
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
else:
is_schnell = False
if is_sd:
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
elif is_sdxl:
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
else:
if args.t5xxl_max_token_length is None:
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
train_util.enable_high_vram(args)
# check cache latents arg
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
# assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
args.cache_latents = True
args.cache_latents_to_disk = True
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random seed sequence
# prepare tokenizer (needed to run the dataset)
if args.sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
is_sd = not args.sdxl and not args.flux
is_sdxl = args.sdxl
is_flux = args.flux
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
if is_sd or is_sdxl:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check)
else:
tokenizer = train_util.load_tokenizer(args)
tokenizers = [tokenizer]
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
use_user_config = args.dataset_config is not None
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
@@ -83,17 +115,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
# if the dataset's cache_latents is not called, raw images are returned
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
# prepare the accelerator
logger.info("prepare accelerator")
@@ -106,72 +132,27 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# load the model
logger.info("load model")
if args.sdxl:
if is_sd:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
elif is_sdxl:
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
else:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
if is_sd or is_sdxl:
if torch.__version__ >= "2.0.0":  # this can be used with xformers builds that support PyTorch 2.0.0 or later
vae.set_use_memory_efficient_attention_xformers(args.xformers)
if torch.__version__ >= "2.0.0":  # this can be used with xformers builds that support PyTorch 2.0.0 or later
vae.set_use_memory_efficient_attention_xformers(args.xformers)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
# prepare the dataloader
train_dataset_group.set_caching_mode("latents")
# note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# prepare the model with accelerator (this should enable multi-GPU use)
train_dataloader = accelerator.prepare(train_dataloader)
# loop to fetch data
for batch in tqdm(train_dataloader):
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
alpha_mask = batch["alpha_mask"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
# split the batch and process it
for i in range(0, b_size, vae_batch_size):
images = batch["images"][i : i + vae_batch_size]
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
image_infos = []
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.image = image
image_info.bucket_reso = bucket_reso
image_info.resized_size = resized_size
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(
image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask
):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop)
# cache latents with dataset
# TODO use DataLoader to speed up
train_dataset_group.new_cache_latents(vae, accelerator)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
accelerator.print(f"Finished caching latents to disk.")
def setup_parser() -> argparse.ArgumentParser:
@@ -181,8 +162,12 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)
config_util.add_config_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
parser.add_argument(
"--no_half_vae",
action="store_true",
@@ -191,7 +176,8 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
)
return parser
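
The new `set_tokenize_strategy` above picks the T5XXL max token length from the model variant when the user does not override it: 256 for FLUX.1 schnell, 512 for dev. A small sketch of that default rule; `t5xxl_max_tokens` is an illustrative helper, not a function from the repo:

```python
from typing import Optional


def t5xxl_max_tokens(is_schnell: bool, override: Optional[int]) -> int:
    # FLUX.1 schnell defaults to 256 tokens, dev to 512, unless overridden
    if override is not None:
        return override
    return 256 if is_schnell else 512


assert t5xxl_max_tokens(True, None) == 256
assert t5xxl_max_tokens(False, None) == 512
assert t5xxl_max_tokens(False, 300) == 300
```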

View File

@@ -9,55 +9,69 @@ from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import (
config_util,
flux_train_utils,
flux_utils,
sdxl_model_util,
strategy_base,
strategy_flux,
strategy_sd,
strategy_sdxl,
)
from library import train_util
from library import sdxl_train_util
from library import utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
from cache_latents import set_tokenize_strategy
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
train_util.enable_high_vram(args)
# check cache arg
assert (
args.cache_text_encoder_outputs_to_disk
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
# prepare as much as possible, but for now only SDXL works
assert (
args.sdxl
), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
args.cache_text_encoder_outputs = True
args.cache_text_encoder_outputs_to_disk = True
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random seed sequence
# prepare tokenizer (needed to run the dataset)
if args.sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
else:
tokenizer = train_util.load_tokenizer(args)
tokenizers = [tokenizer]
is_sd = not args.sdxl and not args.flux
is_sdxl = args.sdxl
is_flux = args.flux
assert (
is_sdxl or is_flux
), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です"
assert (
is_sdxl or args.weighted_captions is None
), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です"
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
# prepare the dataset
use_user_config = args.dataset_config is not None
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
@@ -88,15 +102,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
# prepare the accelerator
logger.info("prepare accelerator")
@@ -105,69 +115,71 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# prepare dtypes for mixed precision and cast as appropriate
weight_dtype, _ = train_util.prepare_dtype(args)
t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype)
# load the model
logger.info("load model")
if args.sdxl:
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
if is_sdxl:
_, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model(
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
)
text_encoder1.to(accelerator.device, weight_dtype)
text_encoder2.to(accelerator.device, weight_dtype)
text_encoders = [text_encoder1, text_encoder2]
else:
text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
text_encoders = [text_encoder1]
clip_l = flux_utils.load_clip_l(
args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors
)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors)
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
if t5xxl.dtype != t5xxl_dtype:
if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize >= 2:
logger.warning(
"The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop."
" / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。"
)
logger.info(f"Casting T5XXL model to {t5xxl_dtype}")
t5xxl.to(t5xxl_dtype)
text_encoders = [clip_l, t5xxl]
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
# prepare the dataloader
train_dataset_group.set_caching_mode("text")
# build text encoder outputs caching strategy
if is_sdxl:
text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
)
else:
text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
# note: persistent_workers cannot be used when the number of DataLoader workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
# build text encoding strategy
if is_sdxl:
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
else:
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# prepare the model with accelerator (this should enable multi-GPU use)
train_dataloader = accelerator.prepare(train_dataloader)
# loop to fetch data
for batch in tqdm(train_dataloader):
absolute_paths = batch["absolute_paths"]
input_ids1_list = batch["input_ids1_list"]
input_ids2_list = batch["input_ids2_list"]
image_infos = []
for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1
image_info.input_ids2 = input_ids2
image_infos.append(image_info)
if len(image_infos) > 0:
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
train_util.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
)
# cache text encoder outputs
train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
accelerator.print(f"Finished caching text encoder outputs to disk.")
def setup_parser() -> argparse.ArgumentParser:
@@ -177,13 +189,29 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)
config_util.add_config_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
parser.add_argument(
"--t5xxl_dtype",
type=str,
default=None,
help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)",
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
)
parser.add_argument(
"--weighted_captions",
action="store_true",
default=False,
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
)
return parser
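
The T5XXL loading path above rejects the fp8 variants it cannot use and warns when an fp8 checkpoint would be cast to a wider dtype. A standalone sketch of those checks; `check_t5xxl_dtype` is an illustrative name, and the float8 dtypes and `torch.dtype.itemsize` require a recent PyTorch (2.1+):

```python
import torch


def check_t5xxl_dtype(loaded: torch.dtype, requested: torch.dtype) -> None:
    # illustrative helper, not a function from the repo
    if loaded in (torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz):
        raise ValueError(f"Unsupported fp8 model dtype: {loaded}")
    if loaded == torch.float8_e4m3fn:
        print("Loaded fp8 T5XXL model")
        if requested.itemsize >= 2:
            print("warning: the model is fp8 but the requested dtype is wider; casting cannot recover lost precision")


check_t5xxl_dtype(torch.float8_e4m3fn, torch.bfloat16)  # prints both messages
```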

View File

@@ -64,7 +64,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy during initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

View File

@@ -116,7 +116,7 @@ class NetworkTrainer:
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False
True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
@@ -384,7 +384,7 @@ class NetworkTrainer:
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -114,7 +114,7 @@ class TextualInversionTrainer:
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False
True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy