mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add GIT captioning, refactoring, DataLoader
This commit is contained in:
@@ -1,20 +1,16 @@
|
||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||
# (c) 2022 Kohya S. @kohya_ss
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKL
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@@ -26,6 +22,16 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
)
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def get_latents(vae, images, weight_dtype):
|
||||
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
||||
img_tensors = torch.stack(img_tensors)
|
||||
@@ -35,9 +41,18 @@ def get_latents(vae, images, weight_dtype):
|
||||
return latents
|
||||
|
||||
|
||||
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
else:
|
||||
base_name = image_key
|
||||
if flip:
|
||||
base_name += '_flip'
|
||||
return os.path.join(data_dir, base_name)
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if os.path.exists(args.in_json):
|
||||
@@ -48,6 +63,25 @@ def main(args):
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
# 既に存在するファイルをfilterする
|
||||
if args.skip_existing:
|
||||
filtered = []
|
||||
for image_path in image_paths:
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
|
||||
npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"
|
||||
if os.path.exists(npz_file_name_flip):
|
||||
if not args.flip_aug:
|
||||
continue
|
||||
|
||||
npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
|
||||
if os.path.exists(npz_file_name_flip):
|
||||
continue
|
||||
|
||||
filtered.apppend(image_path)
|
||||
print(f"number of skipped images (npz already exists) / npzファイルが存在するためスキップした画像数: {len(image_paths) - len(filtered)}")
|
||||
image_paths = filtered
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
@@ -70,15 +104,55 @@ def main(args):
|
||||
buckets_imgs = [[] for _ in range(len(bucket_resos))]
|
||||
bucket_counts = [0 for _ in range(len(bucket_resos))]
|
||||
img_ar_errors = []
|
||||
for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
|
||||
|
||||
def process_batch(is_last):
|
||||
for j in range(len(buckets_imgs)):
|
||||
bucket = buckets_imgs[j]
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
|
||||
|
||||
for (image_key, _, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
|
||||
np.savez(npz_file_name, latent)
|
||||
|
||||
# flip
|
||||
if args.flip_aug:
|
||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||
|
||||
for (image_key, _, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
|
||||
np.savez(npz_file_name, latent)
|
||||
|
||||
bucket.clear()
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.util.data.DataLoader(dataset, batch_size=1, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
if data_entry[0] is None:
|
||||
continue
|
||||
|
||||
img_tensor, image_path = data_entry[0]
|
||||
if img_tensor is not None:
|
||||
image = transforms.functional.to_pil_image(img_tensor)
|
||||
else:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
|
||||
aspect_ratio = image.width / image.height
|
||||
ar_errors = bucket_aspect_ratios - aspect_ratio
|
||||
bucket_id = np.abs(ar_errors).argmin()
|
||||
@@ -123,25 +197,10 @@ def main(args):
|
||||
metadata[image_key]['train_resolution'] = reso
|
||||
|
||||
# バッチを推論するか判定して推論する
|
||||
is_last = i == len(image_paths) - 1
|
||||
for j in range(len(buckets_imgs)):
|
||||
bucket = buckets_imgs[j]
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
|
||||
process_batch(False)
|
||||
|
||||
for (image_key, reso, _), latent in zip(bucket, latents):
|
||||
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
|
||||
np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
|
||||
|
||||
# flip
|
||||
if args.flip_aug:
|
||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||
|
||||
for (image_key, reso, _), latent in zip(bucket, latents):
|
||||
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
|
||||
np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
|
||||
|
||||
bucket.clear()
|
||||
# 残りを処理する
|
||||
process_batch(True)
|
||||
|
||||
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
|
||||
print(f"bucket {i} {reso}: {count}")
|
||||
@@ -162,8 +221,10 @@ if __name__ == '__main__':
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
||||
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--max_resolution", type=str, default="512,512",
|
||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
@@ -174,6 +235,8 @@ if __name__ == '__main__':
|
||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
||||
parser.add_argument("--flip_aug", action="store_true",
|
||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
||||
parser.add_argument("--skip_existing", action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user