Fix existing npz skip feature

This commit is contained in:
Kohya S
2023-02-03 21:05:14 +09:00
parent 73d612ff9c
commit 76f53429be

View File

@@ -63,25 +63,6 @@ def main(args):
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
# 既に存在するファイルをfilterする
if args.skip_existing:
filtered = []
for image_path in image_paths:
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"
if os.path.exists(npz_file_name_flip):
if not args.flip_aug:
continue
npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
if os.path.exists(npz_file_name_flip):
continue
filtered.apppend(image_path)
print(f"number of skipped images (npz already exists) / npzファイルが存在するためスキップした画像数: {len(image_paths) - len(filtered)}")
image_paths = filtered
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
@@ -128,7 +109,7 @@ def main(args):
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.util.data.DataLoader(dataset, batch_size=1, shuffle=False,
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
@@ -153,6 +134,7 @@ def main(args):
if image_key not in metadata:
metadata[image_key] = {}
# 本当はこの部分もDataSetに持っていけば高速化できるがいろいろ大変
aspect_ratio = image.width / image.height
ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
@@ -176,6 +158,25 @@ def main(args):
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
# 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
if args.flip_aug:
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break
dat = np.load(npz_file)['arr_0']
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
continue
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)