diff --git a/library/train_util.py b/library/train_util.py index b9d08f25..d4d646fc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -29,6 +29,7 @@ import hashlib import subprocess from io import BytesIO import toml +import copy from tqdm import tqdm @@ -164,6 +165,8 @@ class ImageInfo: self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.latent_cache_checked: bool = False + self.te_cache_checked: bool = False class BucketManager: @@ -653,6 +656,11 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' + # lists for incremental loading of regularization images + self.reg_infos = None + self.reg_infos_index = None + self.reg_randomize = False + def adjust_min_max_bucket_reso_by_steps( self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int ) -> Tuple[int, int]: @@ -684,6 +692,12 @@ class BaseDataset(torch.utils.data.Dataset): def set_seed(self, seed): self.seed = seed + def set_reg_randomize(self, reg_randomize = False): + self.reg_randomize = reg_randomize + + def incremental_reg_load(self, make_bucket = False): # Placeholder method, does nothing unless overridden in subclasses. + return + def set_caching_mode(self, mode): self.caching_mode = mode @@ -951,11 +965,14 @@ class BaseDataset(torch.utils.data.Dataset): if self.enable_bucket: self.bucket_info = {"buckets": {}} logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + batch_count: int = 0 for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): count = len(bucket) if count > 0: + batch_count += math.ceil(len(bucket) / self.batch_size) self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}, batches: {int(math.ceil(len(bucket) / self.batch_size))}") + logger.info(f"Total batch count: {batch_count}") if len(img_ar_errors) == 0: mean_img_ar_error = 0 # avoid NaN @@ -967,6 +984,7 @@ class BaseDataset(torch.utils.data.Dataset): # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる self.buckets_indices: List[BucketBatchIndex] = [] + self.buckets_indices.clear() for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): @@ -1025,6 +1043,10 @@ class BaseDataset(torch.utils.data.Dataset): logger.info("caching latents.") image_infos = list(self.image_data.values()) + image_infos = list(filter(lambda info: info.latent_cache_checked == False, image_infos)) + if len(image_infos) == 0: + logger.info("All images latents previously checked and cached. Skipping.") + return # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) @@ -1054,11 +1076,17 @@ class BaseDataset(torch.utils.data.Dataset): subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: # fine tuning dataset + info.latent_cache_checked = True + if self.reg_infos is not None and info.image_key in self.reg_infos: + self.reg_infos[info.image_key][0].latent_cache_checked = True continue # check disk cache exists and size of latents if cache_to_disk: info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latent_cache_checked = True + if self.reg_infos is not None and info.image_key in self.reg_infos: + self.reg_infos[info.image_key][0].latent_cache_checked = True if not is_main_process: # store to info only continue @@ -1094,6 +1122,15 @@ class BaseDataset(torch.utils.data.Dataset): logger.info("caching latents...") for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) + if self.reg_infos is not None: + for info in batch: + if info.image_key in self.reg_infos: + self.reg_infos[info.image_key][0].latents_npz = info.latents_npz + self.reg_infos[info.image_key][0].latents_original_size = info.latents_original_size + self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_crop_ltrb + self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_flipped + self.reg_infos[info.image_key][0].latents = info.latents + self.reg_infos[info.image_key][0].alpha_mask = info.alpha_mask # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -1107,6 +1144,10 @@ class BaseDataset(torch.utils.data.Dataset): # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") image_infos = list(self.image_data.values()) + image_infos = list(filter(lambda info: info.te_cache_checked == False, image_infos)) + if len(image_infos) == 0: + logger.info("Text encoder outputs for all images previously checked and cached. Skipping.") + return logger.info("checking cache existence...") image_infos_to_cache = [] @@ -1115,6 +1156,10 @@ class BaseDataset(torch.utils.data.Dataset): if cache_to_disk: te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX info.text_encoder_outputs_npz = te_out_npz + info.te_cache_checked = True + if self.reg_infos is not None: + self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz + self.reg_infos[info.image_key][0].te_cache_checked = True if not is_main_process: # store to info only continue @@ -1157,6 +1202,14 @@ class BaseDataset(torch.utils.data.Dataset): cache_batch_text_encoder_outputs( infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype ) + if self.reg_infos is not None: + for info in batch: + if info.image_key in self.reg_infos: + self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz + self.reg_infos[info.image_key][0].te_cache_checked = True + self.reg_infos[info.image_key][0]. = info.text_encoder_outputs1 = hidden_state1 + self.reg_infos[info.image_key][0]. = info.text_encoder_outputs2 = hidden_state2 + self.reg_infos[info.image_key][0]. = info.text_encoder_pool2 = pool2 def get_image_size(self, image_path): return imagesize.get(image_path) @@ -1561,6 +1614,8 @@ class DreamBoothDataset(BaseDataset): self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.reg_infos: Dict[str, Tuple[ImageInfo, DreamBoothSubset]] = {} + self.reg_infos_index: List[str] = [] self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1689,7 +1744,6 @@ class DreamBoothDataset(BaseDataset): logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 - reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: if subset.num_repeats < 1: logger.warning( @@ -1711,7 +1765,11 @@ class DreamBoothDataset(BaseDataset): continue if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) + if subset.num_repeats > 1: + info.num_repeats = 1 + self.reg_infos[info.image_key] = (info, subset) + for i in range(subset.num_repeats): + self.reg_infos_index.append(info.image_key) else: num_train_images += subset.num_repeats * len(img_paths) @@ -1731,30 +1789,88 @@ class DreamBoothDataset(BaseDataset): self.num_train_images = num_train_images logger.info(f"{num_reg_images} reg images.") - if num_train_images < num_reg_images: - logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - if num_reg_images == 0: - logger.warning("no regularization images / 正則化画像が見つかりませんでした") - else: - # num_repeatsを計算する:どうせ大した数ではないのでループで処理する - n = 0 - first_loop = True - while n < num_train_images: - for info, subset in reg_infos: - if first_loop: - self.register_image(info, subset) - n += info.num_repeats - else: - info.num_repeats += 1 # rewrite registered info - n += 1 - if n >= num_train_images: - break - first_loop = False - self.num_reg_images = num_reg_images + def set_reg_randomize(self, reg_randomize = False): + self.reg_randomize = reg_randomize + # As first set of data is loaded before the first opportunity to shuffle, will need to force reset self.reg_infos_index_traverser and reinitialize dataset + self.reg_infos_index_traverser = 0 + self.bucket_manager = None + self.incremental_reg_load(True) + def subset_loaded_count(self): + count_str = "" + for index, subset in enumerate(self.subsets): + counter = 0 + count_str += f"\nSubset {index} (Class: {subset.class_tokens}): " if isinstance(subset, DreamBoothSubset) and subset.class_tokens is not None else f"\nSubset {index}: " + img_keys = [key for key, value in self.image_to_subset.items() if value == subset] + for img_key in img_keys: + counter += self.image_data[img_key].num_repeats + count_str += f"{counter}/{subset.img_count * subset.num_repeats}" + count_str += f"\nSubset dir: {subset.image_dir}" if subset.image_dir is not None else "" + count_str += f"\n\n" + logger.info(count_str) + + def incremental_reg_load(self, make_bucket = False): + #override to for loading random reg images + distributed_state = PartialState() + + if self.num_reg_images == 0: + logger.warning("no regularization images / 正則化画像が見つかりませんでした") + return + if self.num_train_images < self.num_reg_images: + logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + + if not self.num_train_images == self.num_reg_images: + logger.info(f"Inititating loading of regularizaion images.") + for info, subset in self.reg_infos.values(): + if info.image_key in self.image_data: + self.image_data.pop(info.image_key, None) + self.image_to_subset.pop(info.image_key, None) + + temp_reg_infos = copy.deepcopy(self.reg_infos) + n = 0 + first_loop = True + logger.info(f"self.reg_infos_index_traverser at: {self.reg_infos_index_traverser}\n reg_infos_index len = {len(self.reg_infos_index)}") + reg_img_log = f"\nDataset seed: {self.seed}" + start_index = self.reg_infos_index_traverser + + while n < self.num_train_images : + if self.reg_randomize and self.reg_infos_index_traverser == 0: + if distributed_state.num_processes > 1: + if not distributed_state.is_main_process: + self.reg_infos_index = [] + else: + random.shuffle(self.reg_infos_index) + distributed_state.wait_for_everyone() + self.reg_infos_index = gather_object(self.reg_infos_index) + else: + random.shuffle(self.reg_infos_index) + info, subset = temp_reg_infos[self.reg_infos_index[self.reg_infos_index_traverser]] + if info.image_key in self.image_data: + info.num_repeats += 1 # rewrite registered info + else: + self.register_image(info, subset) + + self.reg_infos_index_traverser += 1 + if self.reg_infos_index_traverser % len(self.reg_infos_index) == 0: + self.reg_infos_index_traverser = 0 + ''' + if n < 5: + reg_img_log += f"\nRegistering image: {info.absolute_path}, count: {info.num_repeats}" + ''' + n += 1 + + # logger.info(reg_img_log) + if distributed_state.is_main_process: + self.subset_loaded_count() + self.bucket_manager = None + if make_bucket: + self.make_buckets() + del temp_reg_infos + else: + logger.warning(f"Number of training images({self.num_train_images}) is the same as number of regularization images({self.num_reg_images}).\nSkipping randomized/incremental loading of regularization images.") + class FineTuningDataset(BaseDataset): def __init__( self, @@ -2098,6 +2214,11 @@ class ControlNetDataset(BaseDataset): self.conditioning_image_transforms = IMAGE_TRANSFORMS + def incremental_reg_load(self, make_bucket = False): + self.dreambooth_dataset_delegate.incremental_reg_load() + if make_bucket: + self.make_buckets() + def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager @@ -2185,9 +2306,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.add_replacement(str_from, str_to) - # def make_buckets(self): - # for dataset in self.datasets: - # dataset.make_buckets() + def set_reg_randomize(self, reg_randomize = False): + for dataset in self.datasets: + dataset.set_reg_randomize(reg_randomize) + + def make_buckets(self): + for dataset in self.datasets: + dataset.make_buckets() def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: @@ -2234,7 +2359,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def disable_token_padding(self): for dataset in self.datasets: dataset.disable_token_padding() + + def incremental_reg_load(self, make_bucket = False): + for dataset in self.datasets: + dataset.incremental_reg_load(make_bucket) + def __len__(self): + self.cumulative_sizes = self.cumsum(self.datasets) + return self.cumulative_sizes[-1] def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 @@ -3579,6 +3711,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", ) + parser.add_argument( + "--incremental_reg_load", + action="store_true", + help="Forces reload of regularization images at each Epoch. Will sequentially load regularization images unless '--randomized_regularization_image' is set. Useful if there are more regularization images than training images", + ) + parser.add_argument( + "--randomized_regularization_image", + action="store_true", + help="Shuffles regularization images to even out distribution. Useful if there are more regularization images than training images", + ) if support_dreambooth: # DreamBooth training