support text_encoder_batch_size for caching
@@ -173,6 +173,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
         action="store_true",
         help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
     )
+    parser.add_argument(
+        "--text_encoder_batch_size",
+        type=int,
+        default=None,
+        help="text encoder batch size (default: None, use dataset's batch size)"
+        + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)",
+    )
     parser.add_argument(
         "--disable_mmap_load_safetensors",
         action="store_true",
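In isolation, the new option behaves like any optional integer argument: when it is omitted, args.text_encoder_batch_size stays None and the caching code falls back to the dataset's batch size. A standalone sketch of that parsing behavior (illustrative only; the real option is registered inside add_sd3_training_arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--text_encoder_batch_size",
    type=int,
    default=None,
    help="text encoder batch size (default: None, use dataset's batch size)",
)

print(parser.parse_args(["--text_encoder_batch_size", "16"]).text_encoder_batch_size)  # 16
print(parser.parse_args([]).text_encoder_batch_size)  # None -> use the dataset's batch size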
@@ -1054,7 +1054,7 @@ class BaseDataset(torch.utils.data.Dataset):

     # same as above, but for SD3
     def cache_text_encoder_outputs_sd3(
-        self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
+        self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
     ):
         return self.cache_text_encoder_outputs_common(
             [tokenizer],
@@ -1065,6 +1065,7 @@ class BaseDataset(torch.utils.data.Dataset):
             cache_to_disk,
             is_main_process,
             TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3,
+            batch_size,
         )

     def cache_text_encoder_outputs_common(
@@ -1077,10 +1078,15 @@ class BaseDataset(torch.utils.data.Dataset):
         cache_to_disk=False,
         is_main_process=True,
         file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
+        batch_size=None,
     ):
         # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
         # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
         logger.info("caching text encoder outputs.")
+
+        if batch_size is None:
+            batch_size = self.batch_size
+
         image_infos = list(self.image_data.values())

         logger.info("checking cache existence...")
@@ -1122,7 +1128,7 @@ class BaseDataset(torch.utils.data.Dataset):
                 l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
                 batch.append((info, l_tokens, g_tokens, t5_tokens))

-                if len(batch) >= self.batch_size:
+                if len(batch) >= batch_size:
                     batches.append(batch)
                     batch = []

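For context, the two Japanese comments in cache_text_encoder_outputs_common translate as "support caching to disk, in the same way as the latents cache" and "multi-GPU is not supported here, so use tools/cache_latents.py for that". The behavioral change in these hunks is small: the flush threshold now compares against the batch_size argument, which falls back to self.batch_size when None, instead of always using the dataset's batch size. A minimal standalone sketch of that pattern (make_batches and its parameters are hypothetical, not the project's API):

def make_batches(items, batch_size=None, dataset_batch_size=1):
    # Mirrors the fallback added above: None means "use the dataset's batch size".
    if batch_size is None:
        batch_size = dataset_batch_size
    batches, batch = [], []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:  # flush once a full batch has been collected
            batches.append(batch)
            batch = []
    if batch:  # keep any trailing partial batch (tail handling is outside the hunk shown)
        batches.append(batch)
    return batches

print(make_batches(range(7), None, 3))  # [[0, 1, 2], [3, 4, 5], [6]]
print(make_batches(range(7), 2, 3))     # [[0, 1], [2, 3], [4, 5], [6]]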
@@ -2209,12 +2215,12 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)

     def cache_text_encoder_outputs_sd3(
-        self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
+        self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
     ):
         for i, dataset in enumerate(self.datasets):
             logger.info(f"[Dataset {i}]")
             dataset.cache_text_encoder_outputs_sd3(
-                tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process
+                tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
             )

     def set_caching_mode(self, caching_mode):
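DatasetGroup only forwards the new argument unchanged to each of its member datasets. A sketch of that fan-out, with toy stand-ins for the project's classes:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ToyDataset:  # hypothetical stand-in for a BaseDataset subclass
    def cache_text_encoder_outputs_sd3(self, batch_size=None):
        logger.info(f"caching text encoder outputs with batch_size={batch_size}")

class ToyDatasetGroup:  # hypothetical stand-in for DatasetGroup
    def __init__(self, datasets):
        self.datasets = datasets

    def cache_text_encoder_outputs_sd3(self, batch_size=None):
        for i, dataset in enumerate(self.datasets):
            logger.info(f"[Dataset {i}]")
            dataset.cache_text_encoder_outputs_sd3(batch_size=batch_size)

ToyDatasetGroup([ToyDataset(), ToyDataset()]).cache_text_encoder_outputs_sd3(batch_size=8)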
@@ -254,6 +254,7 @@ def train(args):
             (None, None, None),
             args.cache_text_encoder_outputs_to_disk,
             accelerator.is_main_process,
+            args.text_encoder_batch_size,
         )
         accelerator.wait_for_everyone()

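Taken together, the value flows from the command line (args.text_encoder_batch_size) through DatasetGroup into BaseDataset.cache_text_encoder_outputs_common. For example, passing --text_encoder_batch_size 1 alongside --cache_text_encoder_outputs_to_disk would cache text encoder outputs one caption at a time, presumably to bound memory use while the text encoders run, without affecting the training batch size.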