diff --git a/library/config_util.py b/library/config_util.py index 10b2457f..5aa5b387 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -576,9 +576,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): logger.info(f"[Dataset {i}]") - dataset.make_buckets() dataset.set_seed(seed) - + dataset.incremental_reg_load() + dataset.make_buckets() + return DatasetGroup(datasets) diff --git a/library/train_util.py b/library/train_util.py index b9d08f25..05e21f83 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -29,6 +29,7 @@ import hashlib import subprocess from io import BytesIO import toml +import copy from tqdm import tqdm @@ -164,6 +165,8 @@ class ImageInfo: self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.latent_cache_checked: bool = False + self.te_cache_checked: bool = False class BucketManager: @@ -653,6 +656,11 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' + # lists for incremental loading of regularization images + self.reg_infos = None + self.reg_infos_index = None + self.reg_randomize = False + def adjust_min_max_bucket_reso_by_steps( self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int ) -> Tuple[int, int]: @@ -684,6 +692,12 @@ class BaseDataset(torch.utils.data.Dataset): def set_seed(self, seed): self.seed = seed + def set_reg_randomize(self, reg_randomize = False): + self.reg_randomize = reg_randomize + + def incremental_reg_load(self, make_bucket = False): # Placeholder method, does nothing unless overridden in subclasses. + return + def set_caching_mode(self, mode): self.caching_mode = mode @@ -951,11 +965,14 @@ class BaseDataset(torch.utils.data.Dataset): if self.enable_bucket: self.bucket_info = {"buckets": {}} logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + batch_count: int = 0 for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): count = len(bucket) if count > 0: + batch_count += math.ceil(len(bucket) / self.batch_size) self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}, batches: {int(math.ceil(len(bucket) / self.batch_size))}") + logger.info(f"Total batch count: {batch_count}") if len(img_ar_errors) == 0: mean_img_ar_error = 0 # avoid NaN @@ -967,6 +984,7 @@ class BaseDataset(torch.utils.data.Dataset): # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる self.buckets_indices: List[BucketBatchIndex] = [] + self.buckets_indices.clear() for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): @@ -1025,6 +1043,10 @@ class BaseDataset(torch.utils.data.Dataset): logger.info("caching latents.") image_infos = list(self.image_data.values()) + image_infos = list(filter(lambda info: info.latent_cache_checked == False, image_infos)) + if len(image_infos) == 0: + logger.info("All images latents previously checked and cached. Skipping.") + return # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) @@ -1054,11 +1076,17 @@ class BaseDataset(torch.utils.data.Dataset): subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: # fine tuning dataset + info.latent_cache_checked = True + if self.reg_infos is not None and info.image_key in self.reg_infos: + self.reg_infos[info.image_key][0].latent_cache_checked = True continue # check disk cache exists and size of latents if cache_to_disk: info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latent_cache_checked = True + if self.reg_infos is not None and info.image_key in self.reg_infos: + self.reg_infos[info.image_key][0].latent_cache_checked = True if not is_main_process: # store to info only continue @@ -1094,6 +1122,15 @@ class BaseDataset(torch.utils.data.Dataset): logger.info("caching latents...") for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) + if self.reg_infos is not None: + for info in batch: + if info.image_key in self.reg_infos: + self.reg_infos[info.image_key][0].latents_npz = info.latents_npz + self.reg_infos[info.image_key][0].latents_original_size = info.latents_original_size + self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_crop_ltrb + self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_flipped + self.reg_infos[info.image_key][0].latents = info.latents + self.reg_infos[info.image_key][0].alpha_mask = info.alpha_mask # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -1107,6 +1144,10 @@ class BaseDataset(torch.utils.data.Dataset): # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") image_infos = list(self.image_data.values()) + image_infos = list(filter(lambda info: info.te_cache_checked == False, image_infos)) + if len(image_infos) == 0: + logger.info("Text encoder outputs for all images previously checked and cached. Skipping.") + return logger.info("checking cache existence...") image_infos_to_cache = [] @@ -1115,6 +1156,10 @@ class BaseDataset(torch.utils.data.Dataset): if cache_to_disk: te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX info.text_encoder_outputs_npz = te_out_npz + info.te_cache_checked = True + if self.reg_infos is not None: + self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz + self.reg_infos[info.image_key][0].te_cache_checked = True if not is_main_process: # store to info only continue @@ -1157,6 +1202,14 @@ class BaseDataset(torch.utils.data.Dataset): cache_batch_text_encoder_outputs( infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype ) + if self.reg_infos is not None: + for info in batch: + if info.image_key in self.reg_infos: + self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz + self.reg_infos[info.image_key][0].te_cache_checked = True + self.reg_infos[info.image_key][0].text_encoder_outputs1 = info.text_encoder_outputs1 + self.reg_infos[info.image_key][0].text_encoder_outputs2 = info.text_encoder_outputs2 + self.reg_infos[info.image_key][0].text_encoder_pool2 = info.text_encoder_pool2 def get_image_size(self, image_path): return imagesize.get(image_path) @@ -1561,6 +1614,9 @@ class DreamBoothDataset(BaseDataset): self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.reg_infos: Dict[str, Tuple[ImageInfo, DreamBoothSubset]] = {} + self.reg_infos_index: List[str] = [] + self.reg_infos_index_traverser = 0 self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1689,7 +1745,6 @@ class DreamBoothDataset(BaseDataset): logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 - reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: if subset.num_repeats < 1: logger.warning( @@ -1720,7 +1775,11 @@ class DreamBoothDataset(BaseDataset): if size is not None: info.image_size = size if subset.is_reg: - reg_infos.append((info, subset)) + if subset.num_repeats > 1: + info.num_repeats = 1 + self.reg_infos[info.image_key] = (info, subset) + for i in range(subset.num_repeats): + self.reg_infos_index.append(info.image_key) else: self.register_image(info, subset) @@ -1731,30 +1790,89 @@ class DreamBoothDataset(BaseDataset): self.num_train_images = num_train_images logger.info(f"{num_reg_images} reg images.") - if num_train_images < num_reg_images: + self.num_reg_images = num_reg_images + self.reg_infos_index_traverser = 0 + + def set_reg_randomize(self, reg_randomize = False): + self.reg_randomize = reg_randomize + # As first set of data is loaded before the first opportunity to shuffle, will need to force reset self.reg_infos_index_traverser and reinitialize dataset + self.reg_infos_index_traverser = 0 + self.bucket_manager = None + self.incremental_reg_load(True) + + def subset_loaded_count(self): + count_str = "" + for index, subset in enumerate(self.subsets): + counter = 0 + count_str += f"\nSubset {index} (Class: {subset.class_tokens}): " if isinstance(subset, DreamBoothSubset) and subset.class_tokens is not None else f"\nSubset {index}: " + img_keys = [key for key, value in self.image_to_subset.items() if value == subset] + for img_key in img_keys: + counter += self.image_data[img_key].num_repeats + count_str += f"{counter}/{subset.img_count * subset.num_repeats}" + count_str += f"\nSubset dir: {subset.image_dir}" if subset.image_dir is not None else "" + count_str += f"\n\n" + logger.info(count_str) + + def incremental_reg_load(self, make_bucket = False): + #override to for loading random reg images + distributed_state = PartialState() + + if self.num_reg_images == 0: + logger.warning("no regularization images / 正則化画像が見つかりませんでした") + return + if self.num_train_images < self.num_reg_images: logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - if num_reg_images == 0: - logger.warning("no regularization images / 正則化画像が見つかりませんでした") - else: - # num_repeatsを計算する:どうせ大した数ではないのでループで処理する + if not self.num_train_images == self.num_reg_images: + logger.info(f"Inititating loading of regularizaion images.") + for info, subset in self.reg_infos.values(): + if info.image_key in self.image_data: + self.image_data.pop(info.image_key, None) + self.image_to_subset.pop(info.image_key, None) + + temp_reg_infos = copy.deepcopy(self.reg_infos) n = 0 first_loop = True - while n < num_train_images: - for info, subset in reg_infos: - if first_loop: - self.register_image(info, subset) - n += info.num_repeats + logger.info(f"self.reg_infos_index_traverser at: {self.reg_infos_index_traverser}\n reg_infos_index len = {len(self.reg_infos_index)}") + reg_img_log = f"\nDataset seed: {self.seed}" + start_index = self.reg_infos_index_traverser + + while n < self.num_train_images : + if self.reg_randomize and self.reg_infos_index_traverser == 0: + if distributed_state.num_processes > 1: + if not distributed_state.is_main_process: + self.reg_infos_index = [] + else: + random.shuffle(self.reg_infos_index) + distributed_state.wait_for_everyone() + self.reg_infos_index = gather_object(self.reg_infos_index) else: - info.num_repeats += 1 # rewrite registered info - n += 1 - if n >= num_train_images: - break - first_loop = False - - self.num_reg_images = num_reg_images + random.shuffle(self.reg_infos_index) + info, subset = temp_reg_infos[self.reg_infos_index[self.reg_infos_index_traverser]] + if info.image_key in self.image_data: + info.num_repeats += 1 # rewrite registered info + else: + self.register_image(info, subset) + self.reg_infos_index_traverser += 1 + if self.reg_infos_index_traverser % len(self.reg_infos_index) == 0: + self.reg_infos_index_traverser = 0 + ''' + if n < 5: + reg_img_log += f"\nRegistering image: {info.absolute_path}, count: {info.num_repeats}" + ''' + n += 1 + # logger.info(reg_img_log) + if distributed_state.is_main_process: + self.subset_loaded_count() + self.bucket_manager = None + if make_bucket: + self.make_buckets() + del temp_reg_infos + else: + logger.warning(f"Number of training images({self.num_train_images}) is the same as number of regularization images({self.num_reg_images}).\nSkipping randomized/incremental loading of regularization images.") + class FineTuningDataset(BaseDataset): def __init__( self, @@ -2098,6 +2216,11 @@ class ControlNetDataset(BaseDataset): self.conditioning_image_transforms = IMAGE_TRANSFORMS + def incremental_reg_load(self, make_bucket = False): + self.dreambooth_dataset_delegate.incremental_reg_load() + if make_bucket: + self.make_buckets() + def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager @@ -2185,9 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.add_replacement(str_from, str_to) - # def make_buckets(self): - # for dataset in self.datasets: - # dataset.make_buckets() + def set_reg_randomize(self, reg_randomize = False): + for dataset in self.datasets: + dataset.set_reg_randomize(reg_randomize) + + def make_buckets(self): + for dataset in self.datasets: + dataset.make_buckets() def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: @@ -2234,7 +2361,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def disable_token_padding(self): for dataset in self.datasets: dataset.disable_token_padding() + + def incremental_reg_load(self, make_bucket = False): + for dataset in self.datasets: + dataset.incremental_reg_load(make_bucket) + def __len__(self): + self.cumulative_sizes = self.cumsum(self.datasets) + return self.cumulative_sizes[-1] def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 @@ -3579,6 +3713,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", ) + parser.add_argument( + "--incremental_reg_load", + action="store_true", + help="Forces reload of regularization images at each Epoch. Will sequentially load regularization images unless '--randomized_regularization_image' is set. Useful if there are more regularization images than training images", + ) + parser.add_argument( + "--randomized_regularization_image", + action="store_true", + help="Shuffles regularization images to even out distribution. Useful if there are more regularization images than training images", + ) if support_dreambooth: # DreamBooth training diff --git a/train_network.py b/train_network.py index 6953bb17..26668b97 100644 --- a/train_network.py +++ b/train_network.py @@ -8,6 +8,7 @@ import time import json from multiprocessing import Value import toml +import copy from tqdm import tqdm @@ -135,6 +136,11 @@ class NetworkTrainer: train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) def train(self, args): + # acceleratorを準備する + logger.info("preparing accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + session_id = random.randint(0, 2**32) training_started_at = time.time() train_util.verify_training_args(args) @@ -202,8 +208,15 @@ class NetworkTrainer: current_epoch = Value("i", 0) current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + if args.incremental_reg_reload: + if args.persistent_data_loader_workers: + logger.warning("persistent_data_loader_workers has been set to False because incremental_reg_reload is enabled.") + args.persistent_data_loader_workers = False + if args.randomized_regularization_image: + # train_dataset_group.set_reg_randomize() triggers a reload to initial state with randomized regularization images. Ensure that this occurs before initial caching to prevent data mismatch + logger.info("Reloading sequentially loaded regularization images to replace with randomly selected regularization images...") + train_dataset_group.set_reg_randomize(args.randomized_regularization_image) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) @@ -221,11 +234,6 @@ class NetworkTrainer: self.assert_extra_args(args, train_dataset_group) - # acceleratorを準備する - logger.info("preparing accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) vae_dtype = torch.float32 if args.no_half_vae else weight_dtype @@ -263,23 +271,24 @@ class NetworkTrainer: accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - accelerator.wait_for_everyone() - - # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される - # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype - ) + ''' + Replacing cache latents and cache text encoder outputs here with code to simulate running through self.cache_text_encoder_outputs_if_needed(). + Reduces unnecessary caching by avoiding caching until data loaded into train_dataset_group has been finalized. + This step is required to ensure text_encoders are loaded onto the correct device for the training. + Possibly should replace with method that can be overridden for different handling of TE for different models. + ''' + if args.cache_text_encoder_outputs and self.is_sdxl: + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + for t_enc in text_encoders: + t_enc.to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + if not self.is_sdxl: + accelerator.print("Text Encoder caching not supported. Overriding args.cache_text_encoder_output to False") + args.cache_text_encoder_outputs = False + for t_enc in text_encoders: + t_enc.to(accelerator.device, dtype=weight_dtype) # prepare network net_kwargs = {} @@ -368,8 +377,12 @@ class NetworkTrainer: # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -845,7 +858,6 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() - del train_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): @@ -885,6 +897,8 @@ class NetworkTrainer: for skip_epoch in range(epoch_to_start): # skip epochs logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") initial_step -= len(train_dataloader) + if args.incremental_reg_reload: + train_dataset_group.incremental_reg_load(True) # Updates the loaded dataset to the next epoch global_step = initial_step for epoch in range(epoch_to_start, num_train_epochs): @@ -893,6 +907,39 @@ class NetworkTrainer: metadata["ss_epoch"] = str(epoch + 1) + if epoch == epoch_to_start or args.incremental_reg_reload: + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu + self.cache_text_encoder_outputs_if_needed( + args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype + ) + accelerator.wait_for_everyone() # Ensure all processes sync after potential dataset/cache changes in initial_step block + + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, # This is the updated train_dataset_group + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, # Ensure n_workers is available + persistent_workers=args.persistent_data_loader_workers, + ) + accelerator.wait_for_everyone() + train_dataloader = accelerator.prepare(train_dataloader) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) skipped_dataloader = None @@ -1091,6 +1138,9 @@ class NetworkTrainer: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # Load next batch of regularization images if necessary + if args.incremental_reg_reload and epoch + 1 < num_train_epochs: + train_dataset_group.incremental_reg_load(True) # end of epoch