Fix bug in FLUX multi-GPU training

This commit is contained in:
kohya-ss
2024-08-22 12:37:41 +09:00
parent e1cd19c0c0
commit 98c91a7625
8 changed files with 156 additions and 38 deletions

View File

@@ -1104,10 +1104,6 @@ class BaseDataset(torch.utils.data.Dataset):
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
batch_size = caching_strategy.batch_size or self.batch_size
# if cache to disk, don't cache TE outputs in non-main process
if caching_strategy.cache_to_disk and not is_main_process:
return
logger.info("caching Text Encoder outputs with caching strategy.")
image_infos = list(self.image_data.values())
@@ -1120,9 +1116,9 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
info.text_encoder_outputs_npz = te_out_npz
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available: # do not add to batch
if cache_available or not is_main_process: # do not add to batch
continue
batch.append(info)
@@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
return train_dataset_group
def load_image(image_path, alpha=False):
def load_image(image_path, alpha=False):
try:
with Image.open(image_path) as image:
if alpha: