mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix bug in FLUX multi GPU training
This commit is contained in:
@@ -1104,10 +1104,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||
batch_size = caching_strategy.batch_size or self.batch_size
|
||||
|
||||
# if cache to disk, don't cache TE outputs in non-main process
|
||||
if caching_strategy.cache_to_disk and not is_main_process:
|
||||
return
|
||||
|
||||
logger.info("caching Text Encoder outputs with caching strategy.")
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
@@ -1120,9 +1116,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process
|
||||
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
|
||||
if cache_available: # do not add to batch
|
||||
if cache_available or not is_main_process: # do not add to batch
|
||||
continue
|
||||
|
||||
batch.append(info)
|
||||
@@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
|
||||
return train_dataset_group
|
||||
|
||||
|
||||
def load_image(image_path, alpha=False):
|
||||
def load_image(image_path, alpha=False):
|
||||
try:
|
||||
with Image.open(image_path) as image:
|
||||
if alpha:
|
||||
|
||||
Reference in New Issue
Block a user