From 42b8d79203411103844b0dd05ef101d3e172f266 Mon Sep 17 00:00:00 2001 From: Darren Laurie <6DammK9@gmail.com> Date: Sun, 9 Mar 2025 02:50:55 +0800 Subject: [PATCH] merge with skip_cache_check, bugfix --- library/strategy_base.py | 5 ++-- library/strategy_sdxl.py | 4 +-- library/train_util.py | 59 +++++++++++++++++++++++----------------- train_native.py | 5 ++-- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 5d7e9593..552a44ec 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -437,11 +437,10 @@ class LatentsCachingStrategy: if not self.cache_to_disk: return False # In multinode training, os.path will hang, but np.load not sure. - if not self.skip_npz_check: - if not os.path.exists(npz_path): - return False if self.skip_disk_cache_validity_check: return True + if not os.path.exists(npz_path): + return False expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 6b3e2afa..15cb752c 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -238,10 +238,10 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(npz_path): - return False if self.skip_disk_cache_validity_check: return True + if not os.path.exists(npz_path): + return False try: npz = np.load(npz_path) diff --git a/library/train_util.py b/library/train_util.py index 5eecf2d9..a0211901 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -137,6 +137,7 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" SKIP_NPZ_PATH_CHECK = False +# Trigger by args.skip_cache_check def set_skip_npz_path_check(skip: bool): global SKIP_NPZ_PATH_CHECK SKIP_NPZ_PATH_CHECK = skip @@ -484,7 +485,7 @@ class BaseSubset: self.validation_seed = validation_seed self.validation_split = validation_split - self.skip_npz_check = skip_npz_check # For multinode training + self.skip_npz_check = skip_npz_check or SKIP_NPZ_PATH_CHECK # For multinode training class DreamBoothSubset(BaseSubset): def __init__( @@ -2168,15 +2169,21 @@ class FineTuningDataset(BaseDataset): debug_dataset: bool, validation_seed: int, validation_split: float, + skip_npz_check: Optional[bool] = SKIP_NPZ_PATH_CHECK, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) self.batch_size = batch_size + self.skip_npz_check = skip_npz_check or SKIP_NPZ_PATH_CHECK + if self.skip_npz_check: + logger.info(f"Skip (VAE) latent checking enabled. Will assign all data path as *.npz directly.") self.num_train_images = 0 self.num_reg_images = 0 for subset in subsets: + #logger.info(f"image_dir: {subset.image_dir}") + if subset.num_repeats < 1: logger.warning( f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" @@ -2209,23 +2216,27 @@ class FineTuningDataset(BaseDataset): # path情報を作る abs_path = None - # まず画像を優先して探す - if npz_path_exists(image_key): - abs_path = image_key + #For speed (make sure it has been checked before training) + if self.skip_npz_check: + abs_path = os.path.join(subset.image_dir, image_key + ".npz") else: - # わりといい加減だがいい方法が思いつかん - paths = glob_images(subset.image_dir, image_key) - if len(paths) > 0: - abs_path = paths[0] - - # なければnpzを探す - if abs_path is None: - if npz_path_exists(os.path.splitext(image_key)[0] + ".npz"): - abs_path = os.path.splitext(image_key)[0] + ".npz" + # まず画像を優先して探す + if npz_path_exists(image_key): + abs_path = image_key else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if npz_path_exists(npz_path): - abs_path = npz_path + # わりといい加減だがいい方法が思いつかん + paths = glob_images(subset.image_dir, image_key) + if len(paths) > 0: + abs_path = paths[0] + + # なければnpzを探す + if abs_path is None: + if npz_path_exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" + else: + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if npz_path_exists(npz_path): + abs_path = npz_path assert abs_path is not None, f"no image / 画像がありません: {image_key}" @@ -2258,8 +2269,13 @@ class FineTuningDataset(BaseDataset): image_info.image_size = img_md.get("train_resolution") if not subset.color_aug and not subset.random_crop: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + # Direct asign to skip most path checking. + if self.skip_npz_check: + image_info.latents_npz = abs_path + image_info.latents_npz_flipped = abs_path.replace(".npz", "_flip.npz") if subset.flip_aug else None + else: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key, abs_path) self.register_image(image_info, subset) @@ -4618,13 +4634,6 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)", ) -def add_skip_check_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--skip_npz_existence_check", - action="store_true", - help="skip check for images and latents existence, useful if your storage has low random access speed / 画像とlatentの存在チェックをスキップする。ストレージのランダムアクセス速度が遅い場合に有用", - ) - def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): if not args.config_file: return args diff --git a/train_native.py b/train_native.py index 45c94a23..79418f4d 100644 --- a/train_native.py +++ b/train_native.py @@ -465,6 +465,8 @@ class NativeTrainer: training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + if args.skip_cache_check: + train_util.set_skip_npz_path_check(True) deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) @@ -1651,7 +1653,6 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) - train_util.add_skip_check_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) @@ -1763,8 +1764,6 @@ if __name__ == "__main__": args = parser.parse_args() train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) - if args.skip_npz_existence_check: - train_util.set_skip_npz_path_check(True) trainer = NativeTrainer() trainer.train(args)