mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
experimental support for multi-GPU latent caching
This commit is contained in:
@@ -981,7 +981,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
]
|
||||
)
|
||||
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator):
|
||||
r"""
|
||||
a brand new method to cache latents. This method caches latents with caching strategy.
|
||||
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
|
||||
@@ -1013,8 +1013,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
batch: List[ImageInfo] = []
|
||||
current_condition = None
|
||||
|
||||
# support multiple GPUs
|
||||
num_processes = accelerator.num_processes
|
||||
process_index = accelerator.process_index
|
||||
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
for i, info in enumerate(tqdm(image_infos)):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
@@ -1024,9 +1028,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if caching_strategy.cache_to_disk:
|
||||
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
|
||||
if not is_main_process: # prepare for multi-gpu, only store to info
|
||||
|
||||
# if the image index i modulo num_processes is not equal to process_index, skip caching
|
||||
# this makes each process cache different latents
|
||||
if i % num_processes != process_index:
|
||||
continue
|
||||
|
||||
print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
|
||||
|
||||
cache_available = caching_strategy.is_disk_cached_latents_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
)
|
||||
@@ -1051,10 +1060,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) > 0:
|
||||
batches.append((current_condition, batch))
|
||||
|
||||
# if cache to disk, don't cache latents in non-main process, set to info only
|
||||
if caching_strategy.cache_to_disk and not is_main_process:
|
||||
return
|
||||
|
||||
if len(batches) == 0:
|
||||
logger.info("no latents to cache")
|
||||
return
|
||||
@@ -2258,8 +2263,8 @@ class ControlNetDataset(BaseDataset):
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process)
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator):
|
||||
return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
|
||||
@@ -2363,10 +2368,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
|
||||
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.new_cache_latents(model, is_main_process)
|
||||
dataset.new_cache_latents(model, accelerator)
|
||||
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||
|
||||
@@ -384,7 +384,7 @@ class NetworkTrainer:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
Reference in New Issue
Block a user