experimental support for multi-gpus latents caching

kohya-ss
2024-09-26 22:19:56 +09:00
parent 3ebb65f945
commit 9249d00311
2 changed files with 17 additions and 12 deletions


@@ -981,7 +981,7 @@ class BaseDataset(torch.utils.data.Dataset):
            ]
        )

-    def new_cache_latents(self, model: Any, is_main_process: bool):
+    def new_cache_latents(self, model: Any, accelerator: Accelerator):
        r"""
        a brand new method to cache latents. This method caches latents with caching strategy.
        normal cache_latents method is used by default, but this method is used when caching strategy is specified.
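
Passing the Accelerator itself instead of a bare is_main_process flag gives the dataset access to the distributed topology. A minimal sketch of the attributes the new code path relies on (illustrative only; Accelerator comes from Hugging Face accelerate):

from accelerate import Accelerator

accelerator = Accelerator()

# attributes the new caching path reads from the accelerator
accelerator.num_processes    # total number of launched processes (one per GPU)
accelerator.process_index    # 0-based rank of the current process
accelerator.is_main_process  # still available where a main-process-only check is needed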
@@ -1013,8 +1013,12 @@ class BaseDataset(torch.utils.data.Dataset):
        batch: List[ImageInfo] = []
        current_condition = None

+        # support multiple-gpus
+        num_processes = accelerator.num_processes
+        process_index = accelerator.process_index
+
        logger.info("checking cache validity...")
-        for info in tqdm(image_infos):
+        for i, info in enumerate(tqdm(image_infos)):
            subset = self.image_to_subset[info.image_key]

            if info.latents_npz is not None:  # fine tuning dataset
@@ -1024,9 +1028,14 @@ class BaseDataset(torch.utils.data.Dataset):
            if caching_strategy.cache_to_disk:
                # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
-                if not is_main_process:  # prepare for multi-gpu, only store to info
+
+                # if the modulo of num_processes is not equal to process_index, skip caching
+                # this makes each process cache different latents
+                if i % num_processes != process_index:
                    continue
+
+                print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")

            cache_available = caching_strategy.is_disk_cached_latents_expected(
                info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
            )
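
The modulo check above is a simple round-robin shard: process k only caches the items whose index i satisfies i % num_processes == k, so together the processes cover the image list exactly once with no overlap. A self-contained sketch of the same idea (not the library's code):

# round-robin sharding: keep every num_processes-th item, offset by this rank
def shard_round_robin(items, num_processes: int, process_index: int):
    return [item for i, item in enumerate(items) if i % num_processes == process_index]

# with 3 processes over 8 items, the shards partition the list:
items = list(range(8))
assert shard_round_robin(items, 3, 0) == [0, 3, 6]
assert shard_round_robin(items, 3, 1) == [1, 4, 7]
assert shard_round_robin(items, 3, 2) == [2, 5]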
@@ -1051,10 +1060,6 @@ class BaseDataset(torch.utils.data.Dataset):
        if len(batch) > 0:
            batches.append((current_condition, batch))

-        # if cache to disk, don't cache latents in non-main process, set to info only
-        if caching_strategy.cache_to_disk and not is_main_process:
-            return
-
        if len(batches) == 0:
            logger.info("no latents to cache")
            return
@@ -2258,8 +2263,8 @@ class ControlNetDataset(BaseDataset):
    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
        return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)

-    def new_cache_latents(self, model: Any, is_main_process: bool):
-        return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process)
+    def new_cache_latents(self, model: Any, accelerator: Accelerator):
+        return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator)

    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
        return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
@@ -2363,10 +2368,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
def new_cache_latents(self, model: Any, is_main_process: bool):
def new_cache_latents(self, model: Any, accelerator: Accelerator):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_latents(model, is_main_process)
dataset.new_cache_latents(model, accelerator)
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
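
For context, a hedged sketch of what the caller side could look like once every dataset in a group accepts the accelerator; the stand-in dataset class below is hypothetical and only illustrates the call pattern plus the synchronization one would typically add before training resumes:

from accelerate import Accelerator

# hypothetical stand-in for a dataset that shards its image list internally
class FakeDataset:
    def new_cache_latents(self, model, accelerator):
        print(f"caching shard {accelerator.process_index + 1}/{accelerator.num_processes}")

accelerator = Accelerator()
datasets = [FakeDataset(), FakeDataset()]

# every rank participates; each dataset caches only its own shard of latents
for dataset in datasets:
    dataset.new_cache_latents(model=None, accelerator=accelerator)

# wait until all ranks have written their .npz files before training continues
accelerator.wait_for_everyone()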