From 8f2ba27869e4c5b9225a309aeed275a47d8eed6a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 20:36:22 +0900 Subject: [PATCH] support text_encoder_batch_size for caching --- library/sd3_train_utils.py | 7 +++++++ library/train_util.py | 14 ++++++++++---- sd3_train.py | 1 + 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 4e45871f..70c83c0b 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -173,6 +173,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", diff --git a/library/train_util.py b/library/train_util.py index c67e8737..96d32e3b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1054,7 +1054,7 @@ class BaseDataset(torch.utils.data.Dataset): # same as above, but for SD3 def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None ): return self.cache_text_encoder_outputs_common( [tokenizer], @@ -1065,6 +1065,7 @@ class BaseDataset(torch.utils.data.Dataset): cache_to_disk, is_main_process, TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + batch_size, ) def cache_text_encoder_outputs_common( @@ -1077,10 +1078,15 @@ class BaseDataset(torch.utils.data.Dataset): cache_to_disk=False, is_main_process=True, file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + batch_size=None, ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + + if batch_size is None: + batch_size = self.batch_size + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -1122,7 +1128,7 @@ class BaseDataset(torch.utils.data.Dataset): l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) batch.append((info, l_tokens, g_tokens, t5_tokens)) - if len(batch) >= self.batch_size: + if len(batch) >= batch_size: batches.append(batch) batch = [] @@ -2209,12 +2215,12 @@ class DatasetGroup(torch.utils.data.ConcatDataset): dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None ): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs_sd3( - tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) def set_caching_mode(self, caching_mode): diff --git a/sd3_train.py b/sd3_train.py index 0721b2ae..8216a62b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -254,6 +254,7 @@ def train(args): (None, None, None), args.cache_text_encoder_outputs_to_disk, accelerator.is_main_process, + args.text_encoder_batch_size, ) accelerator.wait_for_everyone()