diff --git a/library/train_util.py b/library/train_util.py index 3768b605..2ca662dc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -981,7 +981,7 @@ class BaseDataset(torch.utils.data.Dataset): ] ) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. @@ -1013,8 +1013,12 @@ class BaseDataset(torch.utils.data.Dataset): batch: List[ImageInfo] = [] current_condition = None + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + logger.info("checking cache validity...") - for info in tqdm(image_infos): + for i, info in enumerate(tqdm(image_infos)): subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: # fine tuning dataset @@ -1024,9 +1028,14 @@ class BaseDataset(torch.utils.data.Dataset): if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - if not is_main_process: # prepare for multi-gpu, only store to info + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: continue + print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + cache_available = caching_strategy.is_disk_cached_latents_expected( info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask ) @@ -1051,10 +1060,6 @@ class BaseDataset(torch.utils.data.Dataset): if len(batch) > 0: batches.append((current_condition, batch)) - # if cache to disk, don't cache latents in non-main process, set to info only - if caching_strategy.cache_to_disk and not is_main_process: - return - if len(batches) == 0: logger.info("no latents to cache") return @@ -2258,8 +2263,8 @@ class ControlNetDataset(BaseDataset): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) - def new_cache_latents(self, model: Any, is_main_process: bool): - return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + def new_cache_latents(self, model: Any, accelerator: Accelerator): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) @@ -2363,10 +2368,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(model, is_main_process) + dataset.new_cache_latents(model, accelerator) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True diff --git a/train_network.py b/train_network.py index b24f89b1..7eb7aa49 100644 --- a/train_network.py +++ b/train_network.py @@ -384,7 +384,7 @@ class NetworkTrainer: vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device)