mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix existing npz skip feature
This commit is contained in:
@@ -63,25 +63,6 @@ def main(args):
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
# 既に存在するファイルをfilterする
|
||||
if args.skip_existing:
|
||||
filtered = []
|
||||
for image_path in image_paths:
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
|
||||
npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"
|
||||
if os.path.exists(npz_file_name_flip):
|
||||
if not args.flip_aug:
|
||||
continue
|
||||
|
||||
npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
|
||||
if os.path.exists(npz_file_name_flip):
|
||||
continue
|
||||
|
||||
filtered.apppend(image_path)
|
||||
print(f"number of skipped images (npz already exists) / npzファイルが存在するためスキップした画像数: {len(image_paths) - len(filtered)}")
|
||||
image_paths = filtered
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
@@ -128,8 +109,8 @@ def main(args):
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.util.data.DataLoader(dataset, batch_size=1, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
@@ -153,6 +134,7 @@ def main(args):
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
|
||||
# 本当はこの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
||||
aspect_ratio = image.width / image.height
|
||||
ar_errors = bucket_aspect_ratios - aspect_ratio
|
||||
bucket_id = np.abs(ar_errors).argmin()
|
||||
@@ -176,6 +158,25 @@ def main(args):
|
||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
|
||||
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
||||
if args.skip_existing:
|
||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
|
||||
if args.flip_aug:
|
||||
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
|
||||
|
||||
found = True
|
||||
for npz_file in npz_files:
|
||||
if not os.path.exists(npz_file):
|
||||
found = False
|
||||
break
|
||||
|
||||
dat = np.load(npz_file)['arr_0']
|
||||
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
|
||||
found = False
|
||||
break
|
||||
if found:
|
||||
continue
|
||||
|
||||
# 画像をリサイズしてトリミングする
|
||||
# PILにinter_areaがないのでcv2で……
|
||||
image = np.array(image)
|
||||
|
||||
Reference in New Issue
Block a user