From da05ad63390317fb43893e55689db3aff982242b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Dec 2022 21:23:40 +0900 Subject: [PATCH] Fix npz file name for images with dots #12 --- finetune/prepare_buckets_latents.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index e2cebe8d..00f847a1 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -130,14 +130,16 @@ def main(args): latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype) for (image_key, reso, _), latent in zip(bucket, latents): - np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0]), latent) + npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key + np.savez(os.path.join(args.train_data_dir, npz_file_name), latent) # flip if args.flip_aug: latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない for (image_key, reso, _), latent in zip(bucket, latents): - np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0] + '_flip'), latent) + npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key + np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent) bucket.clear()