mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support text_encoder_batch_size for caching
This commit is contained in:
@@ -1054,7 +1054,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# same as above, but for SD3
|
||||
def cache_text_encoder_outputs_sd3(
|
||||
self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
|
||||
self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
|
||||
):
|
||||
return self.cache_text_encoder_outputs_common(
|
||||
[tokenizer],
|
||||
@@ -1065,6 +1065,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
cache_to_disk,
|
||||
is_main_process,
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
def cache_text_encoder_outputs_common(
|
||||
@@ -1077,10 +1078,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
cache_to_disk=False,
|
||||
is_main_process=True,
|
||||
file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
|
||||
batch_size=None,
|
||||
):
|
||||
# latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
|
||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching text encoder outputs.")
|
||||
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
logger.info("checking cache existence...")
|
||||
@@ -1122,7 +1128,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
|
||||
batch.append((info, l_tokens, g_tokens, t5_tokens))
|
||||
|
||||
if len(batch) >= self.batch_size:
|
||||
if len(batch) >= batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
@@ -2209,12 +2215,12 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
|
||||
|
||||
def cache_text_encoder_outputs_sd3(
|
||||
self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
|
||||
self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
|
||||
):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_text_encoder_outputs_sd3(
|
||||
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process
|
||||
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
|
||||
)
|
||||
|
||||
def set_caching_mode(self, caching_mode):
|
||||
|
||||
Reference in New Issue
Block a user