diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py
index 537626d8..d1b9ea25 100644
--- a/finetune/prepare_buckets_latents.py
+++ b/finetune/prepare_buckets_latents.py
@@ -63,25 +63,6 @@ def main(args):
     print(f"no metadata / メタデータファイルがありません: {args.in_json}")
     return
 
-  # filter out files that already exist
-  if args.skip_existing:
-    filtered = []
-    for image_path in image_paths:
-      image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
-
-      npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"
-      if os.path.exists(npz_file_name_flip):
-        if not args.flip_aug:
-          continue
-
-        npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
-        if os.path.exists(npz_file_name_flip):
-          continue
-
-      filtered.apppend(image_path)
-    print(f"number of skipped images (npz already exists) / npzファイルが存在するためスキップした画像数: {len(image_paths) - len(filtered)}")
-    image_paths = filtered
-
   weight_dtype = torch.float32
   if args.mixed_precision == "fp16":
     weight_dtype = torch.float16
@@ -128,8 +109,8 @@ def main(args):
   # option to use a DataLoader to speed up image loading
   if args.max_data_loader_n_workers is not None:
     dataset = train_util.ImageLoadingDataset(image_paths)
-    data = torch.util.data.DataLoader(dataset, batch_size=1, shuffle=False,
-                                      num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
+    data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
+                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
   else:
     data = [[(None, ip)] for ip in image_paths]
 
@@ -153,6 +134,7 @@ def main(args):
     if image_key not in metadata:
       metadata[image_key] = {}
 
+    # this part could also be moved into the DataSet for a further speedup, but it's a lot of work
     aspect_ratio = image.width / image.height
     ar_errors = bucket_aspect_ratios - aspect_ratio
     bucket_id = np.abs(ar_errors).argmin()
@@ -176,6 +158,25 @@ def main(args):
     assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
         1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
 
+    # if the npz files already exist, check their shape and skip the image if it matches
+    if args.skip_existing:
+      npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
+      if args.flip_aug:
+        npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
+
+      found = True
+      for npz_file in npz_files:
+        if not os.path.exists(npz_file):
+          found = False
+          break
+
+        dat = np.load(npz_file)['arr_0']
+        if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8:  # check the shape of the latents
+          found = False
+          break
+      if found:
+        continue
+
     # resize and crop the image
     # use cv2 because PIL has no inter_area
     image = np.array(image)
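
Note on the hunks above: the up-front skip_existing filter is replaced by a per-image check inside the bucketing loop, so a cached .npz is reused only when its latent shape still matches the image's bucket resolution (the VAE downsamples both spatial axes by a factor of 8). A minimal sketch of the shape test, with a hypothetical helper name, assuming the latents were saved under numpy's default arr_0 key:

    import os
    import numpy as np

    def cached_latents_usable(npz_path, reso):
        # reso is (width, height); latents are stored as (C, H // 8, W // 8),
        # so compare the spatial dims against the bucket resolution // 8
        if not os.path.exists(npz_path):
            return False
        dat = np.load(npz_path)["arr_0"]
        return dat.shape[1] == reso[1] // 8 and dat.shape[2] == reso[0] // 8

Moving the check into the loop also retires the old filter's filtered.apppend typo, which raised an AttributeError whenever that branch was reached.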