add GIT captioning, refactoring, DataLoader

This commit is contained in:
Kohya S
2023-02-03 08:45:33 +09:00
parent 8c3a52ecc9
commit 57d8483eaf
9 changed files with 479 additions and 144 deletions

View File

@@ -1,20 +1,16 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from tqdm import tqdm
import numpy as np
from diffusers import AutoencoderKL
from PIL import Image
import cv2
import torch
from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -26,6 +22,16 @@ IMAGE_TRANSFORMS = transforms.Compose(
)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def get_latents(vae, images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors)
@@ -35,9 +41,18 @@ def get_latents(vae, images, weight_dtype):
return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
else:
base_name = image_key
if flip:
base_name += '_flip'
return os.path.join(data_dir, base_name)
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
@@ -48,6 +63,25 @@ def main(args):
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
# 既に存在するファイルをfilterする
if args.skip_existing:
filtered = []
for image_path in image_paths:
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"
if os.path.exists(npz_file_name_flip):
if not args.flip_aug:
continue
npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
if os.path.exists(npz_file_name_flip):
continue
filtered.apppend(image_path)
print(f"number of skipped images (npz already exists) / npzファイルが存在するためスキップした画像数: {len(image_paths) - len(filtered)}")
image_paths = filtered
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
@@ -70,15 +104,55 @@ def main(args):
buckets_imgs = [[] for _ in range(len(bucket_resos))]
bucket_counts = [0 for _ in range(len(bucket_resos))]
img_ar_errors = []
for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
def process_batch(is_last):
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
for (image_key, _, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
np.savez(npz_file_name, latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
np.savez(npz_file_name, latent)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.util.data.DataLoader(dataset, batch_size=1, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
for data_entry in tqdm(data, smoothing=0.0):
if data_entry[0] is None:
continue
img_tensor, image_path = data_entry[0]
if img_tensor is not None:
image = transforms.functional.to_pil_image(img_tensor)
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
if image_key not in metadata:
metadata[image_key] = {}
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
aspect_ratio = image.width / image.height
ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
@@ -123,25 +197,10 @@ def main(args):
metadata[image_key]['train_resolution'] = reso
# バッチを推論するか判定して推論する
is_last = i == len(image_paths) - 1
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
process_batch(False)
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
bucket.clear()
# 残りを処理する
process_batch(True)
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
print(f"bucket {i} {reso}: {count}")
@@ -162,8 +221,10 @@ if __name__ == '__main__':
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_resolution", type=str, default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
@@ -174,6 +235,8 @@ if __name__ == '__main__':
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--flip_aug", action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
args = parser.parse_args()
main(args)