Merge with skip_cache_check; fix a bug

This commit is contained in:
Darren Laurie
2025-03-09 02:50:55 +08:00
parent 190df71e6d
commit 42b8d79203
4 changed files with 40 additions and 33 deletions

View File

@@ -437,11 +437,10 @@ class LatentsCachingStrategy:
if not self.cache_to_disk:
return False
# In multinode training, os.path.exists may hang; whether np.load does as well is unclear.
if not self.skip_npz_check:
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
if not os.path.exists(npz_path):
return False
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)

View File

@@ -238,10 +238,10 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
if not os.path.exists(npz_path):
return False
try:
npz = np.load(npz_path)

View File

@@ -137,6 +137,7 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
SKIP_NPZ_PATH_CHECK = False  # Module-level default: npz path existence checks are performed.
# Set to True via args.skip_cache_check at trainer startup (see NativeTrainer.train).
def set_skip_npz_path_check(skip: bool) -> None:
    """Globally enable or disable skipping of npz cache-path existence checks.

    Intended for multinode training, where os.path.exists on shared storage can hang.
    """
    global SKIP_NPZ_PATH_CHECK
    SKIP_NPZ_PATH_CHECK = skip
@@ -484,7 +485,7 @@ class BaseSubset:
self.validation_seed = validation_seed
self.validation_split = validation_split
self.skip_npz_check = skip_npz_check # For multinode training
self.skip_npz_check = skip_npz_check or SKIP_NPZ_PATH_CHECK # For multinode training
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -2168,15 +2169,21 @@ class FineTuningDataset(BaseDataset):
debug_dataset: bool,
validation_seed: int,
validation_split: float,
skip_npz_check: Optional[bool] = SKIP_NPZ_PATH_CHECK,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
self.batch_size = batch_size
self.skip_npz_check = skip_npz_check or SKIP_NPZ_PATH_CHECK
if self.skip_npz_check:
logger.info(f"Skip (VAE) latent checking enabled. Will assign all data path as *.npz directly.")
self.num_train_images = 0
self.num_reg_images = 0
for subset in subsets:
#logger.info(f"image_dir: {subset.image_dir}")
if subset.num_repeats < 1:
logger.warning(
f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
@@ -2209,23 +2216,27 @@ class FineTuningDataset(BaseDataset):
# path情報を作る
abs_path = None
# まず画像を優先して探す
if npz_path_exists(image_key):
abs_path = image_key
# For speed (make sure the npz files have been checked before training)
if self.skip_npz_check:
abs_path = os.path.join(subset.image_dir, image_key + ".npz")
else:
# わりといい加減だがいい方法が思いつかん
paths = glob_images(subset.image_dir, image_key)
if len(paths) > 0:
abs_path = paths[0]
# なければnpzを探す
if abs_path is None:
if npz_path_exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
# まず画像を優先して探す
if npz_path_exists(image_key):
abs_path = image_key
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if npz_path_exists(npz_path):
abs_path = npz_path
# わりといい加減だがいい方法が思いつかん
paths = glob_images(subset.image_dir, image_key)
if len(paths) > 0:
abs_path = paths[0]
# なければnpzを探す
if abs_path is None:
if npz_path_exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if npz_path_exists(npz_path):
abs_path = npz_path
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
@@ -2258,8 +2269,13 @@ class FineTuningDataset(BaseDataset):
image_info.image_size = img_md.get("train_resolution")
if not subset.color_aug and not subset.random_crop:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
# Assign directly to skip most path checking.
if self.skip_npz_check:
image_info.latents_npz = abs_path
image_info.latents_npz_flipped = abs_path.replace(".npz", "_flip.npz") if subset.flip_aug else None
else:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key, abs_path)
self.register_image(image_info, subset)
@@ -4618,13 +4634,6 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時",
)
def add_skip_check_arguments(parser: argparse.ArgumentParser):
    """Register the --skip_npz_existence_check flag on the given parser."""
    help_text = "skip check for images and latents existence, useful if your storage has low random access speed / 画像とlatentの存在チェックをスキップする。ストレージのランダムアクセス速度が遅い場合に有用"
    # store_true: the option defaults to False and becomes True when the flag is supplied.
    parser.add_argument("--skip_npz_existence_check", action="store_true", help=help_text)
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
if not args.config_file:
return args

View File

@@ -465,6 +465,8 @@ class NativeTrainer:
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
if args.skip_cache_check:
train_util.set_skip_npz_path_check(True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
@@ -1651,7 +1653,6 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_skip_check_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -1763,8 +1764,6 @@ if __name__ == "__main__":
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.skip_npz_existence_check:
train_util.set_skip_npz_path_check(True)
trainer = NativeTrainer()
trainer.train(args)