mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
merge with skip_cache_check, bugfix
This commit is contained in:
@@ -437,11 +437,10 @@ class LatentsCachingStrategy:
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
# In multinode training, os.path will hang, but np.load not sure.
|
||||
if not self.skip_npz_check:
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
|
||||
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
|
||||
|
||||
@@ -238,10 +238,10 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
|
||||
@@ -137,6 +137,7 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
SKIP_NPZ_PATH_CHECK = False
|
||||
|
||||
# Trigger by args.skip_cache_check
|
||||
def set_skip_npz_path_check(skip: bool):
|
||||
global SKIP_NPZ_PATH_CHECK
|
||||
SKIP_NPZ_PATH_CHECK = skip
|
||||
@@ -484,7 +485,7 @@ class BaseSubset:
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
self.skip_npz_check = skip_npz_check # For multinode training
|
||||
self.skip_npz_check = skip_npz_check or SKIP_NPZ_PATH_CHECK # For multinode training
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
def __init__(
|
||||
@@ -2168,15 +2169,21 @@ class FineTuningDataset(BaseDataset):
|
||||
debug_dataset: bool,
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
skip_npz_check: Optional[bool] = SKIP_NPZ_PATH_CHECK,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.skip_npz_check = skip_npz_check or SKIP_NPZ_PATH_CHECK
|
||||
if self.skip_npz_check:
|
||||
logger.info(f"Skip (VAE) latent checking enabled. Will assign all data path as *.npz directly.")
|
||||
|
||||
self.num_train_images = 0
|
||||
self.num_reg_images = 0
|
||||
|
||||
for subset in subsets:
|
||||
#logger.info(f"image_dir: {subset.image_dir}")
|
||||
|
||||
if subset.num_repeats < 1:
|
||||
logger.warning(
|
||||
f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
|
||||
@@ -2209,23 +2216,27 @@ class FineTuningDataset(BaseDataset):
|
||||
# path情報を作る
|
||||
abs_path = None
|
||||
|
||||
# まず画像を優先して探す
|
||||
if npz_path_exists(image_key):
|
||||
abs_path = image_key
|
||||
#For speed (make sure it has been checked before training)
|
||||
if self.skip_npz_check:
|
||||
abs_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
paths = glob_images(subset.image_dir, image_key)
|
||||
if len(paths) > 0:
|
||||
abs_path = paths[0]
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if npz_path_exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
# まず画像を優先して探す
|
||||
if npz_path_exists(image_key):
|
||||
abs_path = image_key
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if npz_path_exists(npz_path):
|
||||
abs_path = npz_path
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
paths = glob_images(subset.image_dir, image_key)
|
||||
if len(paths) > 0:
|
||||
abs_path = paths[0]
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if npz_path_exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if npz_path_exists(npz_path):
|
||||
abs_path = npz_path
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
|
||||
@@ -2258,8 +2269,13 @@ class FineTuningDataset(BaseDataset):
|
||||
image_info.image_size = img_md.get("train_resolution")
|
||||
|
||||
if not subset.color_aug and not subset.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
# Direct asign to skip most path checking.
|
||||
if self.skip_npz_check:
|
||||
image_info.latents_npz = abs_path
|
||||
image_info.latents_npz_flipped = abs_path.replace(".npz", "_flip.npz") if subset.flip_aug else None
|
||||
else:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key, abs_path)
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
@@ -4618,13 +4634,6 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
||||
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)",
|
||||
)
|
||||
|
||||
def add_skip_check_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--skip_npz_existence_check",
|
||||
action="store_true",
|
||||
help="skip check for images and latents existence, useful if your storage has low random access speed / 画像とlatentの存在チェックをスキップする。ストレージのランダムアクセス速度が遅い場合に有用",
|
||||
)
|
||||
|
||||
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
if not args.config_file:
|
||||
return args
|
||||
|
||||
@@ -465,6 +465,8 @@ class NativeTrainer:
|
||||
training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
if args.skip_cache_check:
|
||||
train_util.set_skip_npz_path_check(True)
|
||||
deepspeed_utils.prepare_deepspeed_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
@@ -1651,7 +1653,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_skip_check_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
@@ -1763,8 +1764,6 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
if args.skip_npz_existence_check:
|
||||
train_util.set_skip_npz_path_check(True)
|
||||
|
||||
trainer = NativeTrainer()
|
||||
trainer.train(args)
|
||||
|
||||
Reference in New Issue
Block a user