From 392e8dedd84e469b125e2935e3ecf02e6270a5b2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:14:11 +0900 Subject: [PATCH] fix flip_aug, alpha_mask, random_crop issue in caching in caching strategy --- library/train_util.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 319337a4..17dd447e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -993,9 +993,26 @@ class BaseDataset(torch.utils.data.Dataset): # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1016,20 +1033,23 @@ class BaseDataset(torch.utils.data.Dataset): if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= caching_strategy.batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) # if cache to disk, don't cache latents in non-main process, set to info only if caching_strategy.cache_to_disk and not is_main_process: @@ -1041,9 +1061,8 @@ class BaseDataset(torch.utils.data.Dataset): # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと