support text_encoder_batch_size for caching

Author: Kohya S
Date:   2024-06-26 20:36:22 +09:00
parent 4802e4aaec
commit 8f2ba27869
3 changed files with 18 additions and 4 deletions


@@ -173,6 +173,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
action="store_true", action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
) )
parser.add_argument(
"--text_encoder_batch_size",
type=int,
default=None,
help="text encoder batch size (default: None, use dataset's batch size)"
+ " / text encoderのバッチサイズデフォルト: None, データセットのバッチサイズを使用)",
)
parser.add_argument( parser.add_argument(
"--disable_mmap_load_safetensors", "--disable_mmap_load_safetensors",
action="store_true", action="store_true",


@@ -1054,7 +1054,7 @@ class BaseDataset(torch.utils.data.Dataset):
     # same as above, but for SD3
     def cache_text_encoder_outputs_sd3(
-        self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
+        self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
     ):
         return self.cache_text_encoder_outputs_common(
             [tokenizer],
@@ -1065,6 +1065,7 @@ class BaseDataset(torch.utils.data.Dataset):
             cache_to_disk,
             is_main_process,
             TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3,
+            batch_size,
         )

     def cache_text_encoder_outputs_common(
@@ -1077,10 +1078,15 @@ class BaseDataset(torch.utils.data.Dataset):
         cache_to_disk=False,
         is_main_process=True,
         file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
+        batch_size=None,
     ):
         # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
         # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
         logger.info("caching text encoder outputs.")
+
+        if batch_size is None:
+            batch_size = self.batch_size
+
         image_infos = list(self.image_data.values())

         logger.info("checking cache existence...")
@@ -1122,7 +1128,7 @@ class BaseDataset(torch.utils.data.Dataset):
             l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
             batch.append((info, l_tokens, g_tokens, t5_tokens))

-            if len(batch) >= self.batch_size:
+            if len(batch) >= batch_size:
                 batches.append(batch)
                 batch = []
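The hunks above are where the option takes effect: cache_text_encoder_outputs_common falls back to the dataset's batch size when batch_size is None, then groups image infos into batches of that size before running the text encoders. A self-contained sketch of just that fallback-and-grouping logic (the function name and sample data are illustrative, not from the repo):

from typing import Iterable, List, Optional


def make_caching_batches(items: Iterable, batch_size: Optional[int], dataset_batch_size: int) -> List[list]:
    # Same fallback as cache_text_encoder_outputs_common: an explicit
    # batch_size wins; None means "use the dataset's batch size".
    if batch_size is None:
        batch_size = dataset_batch_size

    batches, batch = [], []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            batches.append(batch)
            batch = []
    if len(batch) > 0:  # flush the final partial batch
        batches.append(batch)
    return batches


# 10 items, dataset batch size 4, --text_encoder_batch_size 8 -> batches of 8 and 2
print([len(b) for b in make_caching_batches(range(10), 8, 4)])    # [8, 2]
# Flag omitted (None) -> the dataset's batch size of 4 is used
print([len(b) for b in make_caching_batches(range(10), None, 4)])  # [4, 4, 2]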
@@ -2209,12 +2215,12 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)

     def cache_text_encoder_outputs_sd3(
-        self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
+        self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
     ):
         for i, dataset in enumerate(self.datasets):
             logger.info(f"[Dataset {i}]")
             dataset.cache_text_encoder_outputs_sd3(
-                tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process
+                tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
             )

     def set_caching_mode(self, caching_mode):


@@ -254,6 +254,7 @@ def train(args):
                 (None, None, None),
                 args.cache_text_encoder_outputs_to_disk,
                 accelerator.is_main_process,
+                args.text_encoder_batch_size,
             )

             accelerator.wait_for_everyone()