Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
Add scripts.
.gitignore (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
logs
__pycache__
wd14_tagger_model
README-ja.md (12 lines changed)
@@ -3,9 +3,21 @@ A repository of scripts for Stable Diffusion training, image generation, and other tools.
[README in English](./README.md)

A GUI, PowerShell scripts, and other features that make these tools easier to use are provided in [bmaltais's repository](https://github.com/bmaltais/kohya_ss) (in English); please see it as well. Thanks to bmaltais!

The following scripts are included:

* DreamBooth training, supporting both U-Net and Text Encoder
* fine-tuning, likewise supporting both
* image generation
* model conversion (conversion between Stable Diffusion ckpt/safetensors and Diffusers)

## Usage

Articles are available on note.com; please refer to them (they may be moved here in the future).

* [Environment setup and the DreamBooth training script](https://note.com/kohya_ss/n/nee3ed1649fb6)
* [Fine-tuning script](https://note.com/kohya_ss/n/nbf7ce8d80f29):
  including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
* [Image generation script](https://note.com/kohya_ss/n/n2693183a798e)
* [Model conversion script](https://note.com/kohya_ss/n/n374f316fe4ad)
README.md (20 lines changed)
@@ -2,11 +2,27 @@ This repository contains training, generation and utility scripts for Stable Diffusion.
[README in Japanese](./README-ja.md)

For easier use (GUI, PowerShell scripts, etc.), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!

This repository contains the scripts for:

* DreamBooth training, including U-Net and Text Encoder
* fine-tuning (native training), including U-Net and Text Encoder
* image generation
* model conversion (supports 1.x and 2.x, Stable Diffusion ckpt/safetensors and Diffusers)
## About requirements_*.txt

These files intentionally omit PyTorch and Diffusers, because the appropriate versions depend on your environment. Install PyTorch first, then Diffusers.

The scripts are tested with PyTorch 1.12.1 and 1.13.0, and Diffusers 0.10.2.
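A minimal setup sketch for reference (the CUDA 11.6 index URL and the exact pins are assumptions; pick the PyTorch build that matches your environment):

    pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
    pip install diffusers==0.10.2
    pip install -r requirements_db_finetune.txt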
## Links to how-to-use documents

All documents are currently in Japanese, and describe command-line usage.

* [Environment setup and DreamBooth training guide](https://note.com/kohya_ss/n/nee3ed1649fb6)
* [Fine-tuning step-by-step guide](https://note.com/kohya_ss/n/nbf7ce8d80f29):
  including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
* [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
* [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
clean_captions_and_tags.py (new file, 123 lines)
@@ -0,0 +1,123 @@
# This script is licensed under the Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss

import argparse
import glob
import os
import json

from tqdm import tqdm


def clean_tags(image_key, tags):
    # replace '_' with ' '
    tags = tags.replace('_', ' ')

    # remove rating: DeepDanbooru only
    tokens = tags.split(", rating")
    if len(tokens) == 1:
        # WD14 tagger output takes this path, so no message is printed
        # print("no rating:")
        # print(f"{image_key} {tags}")
        pass
    else:
        if len(tokens) > 2:
            print("multiple ratings:")
            print(f"{image_key} {tags}")
        tags = tokens[0]

    return tags


# searched and replaced in order, top to bottom
# ('string to search for', 'replacement')
CAPTION_REPLACEMENTS = [
    ('anime anime', 'anime'),
    ('young ', ''),
    ('anime girl', 'girl'),
    ('cartoon female', 'girl'),
    ('cartoon lady', 'girl'),
    ('cartoon character', 'girl'),      # a or ~s
    ('cartoon woman', 'girl'),
    ('cartoon women', 'girls'),
    ('cartoon girl', 'girl'),
    ('anime female', 'girl'),
    ('anime lady', 'girl'),
    ('anime character', 'girl'),        # a or ~s
    ('anime woman', 'girl'),
    ('anime women', 'girls'),
    ('lady', 'girl'),
    ('female', 'girl'),
    ('woman', 'girl'),
    ('women', 'girls'),
    ('people', 'girls'),
    ('person', 'girl'),
    ('a cartoon figure', 'a figure'),
    ('a cartoon image', 'an image'),
    ('a cartoon picture', 'a picture'),
    ('an anime cartoon image', 'an image'),
    ('a cartoon anime drawing', 'a drawing'),
    ('a cartoon drawing', 'a drawing'),
    ('girl girl', 'girl'),
]


def clean_caption(caption):
    for rf, rt in CAPTION_REPLACEMENTS:
        replaced = True
        while replaced:
            bef = caption
            caption = caption.replace(rf, rt)
            replaced = bef != caption
    return caption


def main(args):
    if os.path.exists(args.in_json):
        print(f"loading existing metadata: {args.in_json}")
        with open(args.in_json, "rt", encoding='utf-8') as f:
            metadata = json.load(f)
    else:
        print("no metadata / メタデータファイルがありません")
        return

    print("cleaning captions and tags.")
    image_keys = list(metadata.keys())
    for image_key in tqdm(image_keys):
        tags = metadata[image_key].get('tags')
        if tags is None:
            print(f"image does not have tags / メタデータにタグがありません: {image_key}")
        else:
            metadata[image_key]['tags'] = clean_tags(image_key, tags)

        caption = metadata[image_key].get('caption')
        if caption is None:
            print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
        else:
            metadata[image_key]['caption'] = clean_caption(caption)

    # write out the metadata and finish
    print(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")

    args, unknown = parser.parse_known_args()
    if len(unknown) == 1:
        print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
        print("All captions and tags in the metadata are processed.")
        print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
        print("メタデータ内のすべてのキャプションとタグが処理されます。")
        args.in_json = args.out_json
        args.out_json = unknown[0]
    elif len(unknown) > 0:
        raise ValueError(f"error: unrecognized arguments: {unknown}")

    main(args)
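A minimal usage sketch of the iterative replacement above (the caption string is hypothetical); each rule is re-applied until the text stops changing:

    from clean_captions_and_tags import clean_caption

    print(clean_caption("a young anime woman"))  # 'young ' is dropped, then 'anime woman' -> 'girl', giving "a girl"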
convert_diffusers20_original_sd.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
# v1: initial version
# v2: support safetensors
# v3: fix to support another format
# v4: support safetensors in Diffusers

import argparse
import os
import torch
from diffusers import StableDiffusionPipeline

import model_util


def convert(args):
    # check the arguments
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16:
        save_dtype = torch.float16
    elif args.bf16:
        save_dtype = torch.bfloat16
    elif args.float:
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
    assert is_save_ckpt or args.reference_model is not None, "reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"

    # load the model
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
    print(f"loading {msg}: {args.model_to_load}")

    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # auto-detect the model version
            v2_model = unet.config.cross_attention_dim == 1024
            print("checking model version: model is " + ('v2' if v2_model else 'v1'))
        else:
            v2_model = args.v2

    # convert and save
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
    print(f"converting and saving as {msg}: {args.model_to_save}")

    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
                                                                original_model, args.epoch, args.global_step, save_dtype, vae)
        print(f"model saved. total converted state_dict keys: {key_count}")
    else:
        print(f"copy scheduler/tokenizer config from: {args.reference_model}")
        model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
        print("model saved.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--v1", action='store_true',
                        help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
    parser.add_argument("--v2", action='store_true',
                        help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
    parser.add_argument("--fp16", action='store_true',
                        help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
    parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
    parser.add_argument("--float", action='store_true',
                        help='save as float (float32) (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
    parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
    parser.add_argument("--global_step", type=int, default=0,
                        help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
    parser.add_argument("--reference_model", type=str, default=None,
                        help="reference model for scheduler/tokenizer, required to save as Diffusers; scheduler/tokenizer are copied from this model / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
    parser.add_argument("--use_safetensors", action='store_true',
                        help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Diffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")

    parser.add_argument("model_to_load", type=str, default=None,
                        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
    parser.add_argument("model_to_save", type=str, default=None,
                        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存")

    args = parser.parse_args()
    convert(args)
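A minimal command-line sketch (file and directory names are hypothetical; the reference model shown is the public v1.5 Diffusers repo): convert a v1.x checkpoint to a Diffusers directory, then back to a single fp16 safetensors checkpoint:

    python convert_diffusers20_original_sd.py --v1 --reference_model runwayml/stable-diffusion-v1-5 model.ckpt converted_diffusers
    python convert_diffusers20_original_sd.py --fp16 converted_diffusers model_fp16.safetensors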
detect_face_rotate.py (new file, 239 lines)
@@ -0,0 +1,239 @@
# This script is licensed under the Apache License 2.0, the same as train_dreambooth.py
# (c) 2022 Kohya S. @kohya_ss

# Detect a face in a (landscape) image, rotate the image so the face is upright,
# and crop a square centered on the face

# v2: extract max face if multiple faces are found
# v3: add crop_ratio option
# v4: add multiple faces extraction and min/max size

import argparse
import math
import cv2
import glob
import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np

KP_REYE = 11
KP_LEYE = 19

SCORE_THRES = 0.90


def detect_faces(detector, image, min_size):
    preds = detector(image)  # bgr
    # print(len(preds))

    faces = []
    for pred in preds:
        bb = pred['bbox']
        score = bb[-1]
        if score < SCORE_THRES:
            continue

        left, top, right, bottom = bb[:4]
        cx = int((left + right) / 2)
        cy = int((top + bottom) / 2)
        fw = int(right - left)
        fh = int(bottom - top)

        lex, ley = pred['keypoints'][KP_LEYE, 0:2]
        rex, rey = pred['keypoints'][KP_REYE, 0:2]
        angle = math.atan2(ley - rey, lex - rex)
        angle = angle / math.pi * 180

        faces.append((cx, cy, fw, fh, angle))

    faces.sort(key=lambda x: max(x[2], x[3]), reverse=True)  # largest faces first
    return faces


def rotate_image(image, angle, cx, cy):
    h, w = image.shape[0:2]
    rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)

    # # enlarge the image slightly to make room for the rotation -> disabled for now
    # nh = max(h, int(w * math.sin(angle)))
    # nw = max(w, int(h * math.sin(angle)))
    # if nh > h or nw > w:
    #     pad_y = nh - h
    #     pad_t = pad_y // 2
    #     pad_x = nw - w
    #     pad_l = pad_x // 2
    #     m = np.array([[0, 0, pad_l],
    #                   [0, 0, pad_t]])
    #     rot_mat = rot_mat + m
    #     h, w = nh, nw
    #     cx += pad_l
    #     cy += pad_t

    result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    return result, cx, cy


def process(args):
    assert (not args.resize_fit) or args.resize_face_size is None, "resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
    assert args.crop_ratio is None or args.resize_face_size is None, "resize_face_size can't be specified with crop_ratio / crop_ratio指定時はresize_face_sizeは指定できません"

    # load the anime face detection model
    print("loading face detector.")
    detector = create_detector('yolov3')

    # parse the crop arguments
    if args.crop_size is None:
        crop_width = crop_height = None
    else:
        tokens = args.crop_size.split(',')
        assert len(tokens) == 2, "crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
        crop_width, crop_height = [int(t) for t in tokens]

    if args.crop_ratio is None:
        crop_h_ratio = crop_v_ratio = None
    else:
        tokens = args.crop_ratio.split(',')
        assert len(tokens) == 2, "crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
        crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]

    # process the images
    print("processing.")
    output_extension = ".png"

    os.makedirs(args.dst_dir, exist_ok=True)
    paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.src_dir, "*.webp"))
    for path in tqdm(paths):
        basename = os.path.splitext(os.path.basename(path))[0]

        # image = cv2.imread(path)  # fails on Japanese filenames
        image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.shape[2] == 4:
            print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
            image = image[:, :, :3].copy()  # without the copy, the alpha information apparently stays attached internally

        h, w = image.shape[:2]

        faces = detect_faces(detector, image, args.multiple_faces)
        for i, face in enumerate(faces):
            cx, cy, fw, fh, angle = face
            face_size = max(fw, fh)
            if args.min_size is not None and face_size < args.min_size:
                continue
            if args.max_size is not None and face_size >= args.max_size:
                continue
            face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""

            # rotate if requested
            face_img = image
            if args.rotate:
                face_img, cx, cy = rotate_image(face_img, angle, cx, cy)

            # crop centered on the face if requested
            if crop_width is not None or crop_h_ratio is not None:
                cur_crop_width, cur_crop_height = crop_width, crop_height
                if crop_h_ratio is not None:
                    cur_crop_width = int(face_size * crop_h_ratio + .5)
                    cur_crop_height = int(face_size * crop_v_ratio + .5)

                # resize if necessary
                scale = 1.0
                if args.resize_face_size is not None:
                    # resize based on the face size
                    scale = args.resize_face_size / face_size
                    if scale < cur_crop_width / w:
                        print(
                            f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
                        scale = cur_crop_width / w
                    if scale < cur_crop_height / h:
                        print(
                            f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
                        scale = cur_crop_height / h
                elif crop_h_ratio is not None:
                    # don't resize when a ratio is specified
                    pass
                else:
                    # an explicit crop size is specified
                    if w < cur_crop_width:
                        print(f"image width too small / 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
                        scale = cur_crop_width / w
                    if h < cur_crop_height:
                        print(f"image height too small / 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
                        scale = cur_crop_height / h
                    if args.resize_fit:
                        scale = max(cur_crop_width / w, cur_crop_height / h)

                if scale != 1.0:
                    w = int(w * scale + .5)
                    h = int(h * scale + .5)
                    face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
                    cx = int(cx * scale + .5)
                    cy = int(cy * scale + .5)
                    fw = int(fw * scale + .5)
                    fh = int(fh * scale + .5)

                cur_crop_width = min(cur_crop_width, face_img.shape[1])
                cur_crop_height = min(cur_crop_height, face_img.shape[0])

                x = cx - cur_crop_width // 2
                cx = cur_crop_width // 2
                if x < 0:
                    cx = cx + x
                    x = 0
                elif x + cur_crop_width > w:
                    cx = cx + (x + cur_crop_width - w)
                    x = w - cur_crop_width
                face_img = face_img[:, x:x+cur_crop_width]

                y = cy - cur_crop_height // 2
                cy = cur_crop_height // 2
                if y < 0:
                    cy = cy + y
                    y = 0
                elif y + cur_crop_height > h:
                    cy = cy + (y + cur_crop_height - h)
                    y = h - cur_crop_height
                face_img = face_img[y:y + cur_crop_height]

            # # debug
            # print(path, cx, cy, angle)
            # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
            # cv2.imshow("image", crp)
            # if cv2.waitKey() == 27:
            #     break
            # cv2.destroyAllWindows()

            # debug
            if args.debug:
                cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)

            _, buf = cv2.imencode(output_extension, face_img)
            with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
                buf.tofile(f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
    parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
    parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
    parser.add_argument("--resize_fit", action="store_true",
                        help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
    parser.add_argument("--resize_face_size", type=int, default=None,
                        help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
    parser.add_argument("--crop_size", type=str, default=None,
                        help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
    parser.add_argument("--crop_ratio", type=str, default=None,
                        help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
    parser.add_argument("--min_size", type=int, default=None,
                        help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
    parser.add_argument("--max_size", type=int, default=None,
                        help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
    parser.add_argument("--multiple_faces", action="store_true",
                        help="output each face found / 複数の顔が見つかった場合、それぞれを切り出す")
    parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
    args = parser.parse_args()

    process(args)
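A minimal command-line sketch (directory names are hypothetical): rotate each detected face upright and crop a 512x512 square around it, upscaling when the crop would not otherwise fit:

    python detect_face_rotate.py --src_dir raw_images --dst_dir face_crops --rotate --crop_size 512,512 --resize_fit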
fine_tune.py (new file, 1059 lines; diff suppressed because it is too large)
gen_img_diffusers.py (new file, 2517 lines; diff suppressed because it is too large)
hypernetwork_nai.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# NAI compatible

import torch


class HypernetworkModule(torch.nn.Module):
    def __init__(self, dim, multiplier=1.0):
        super().__init__()

        linear1 = torch.nn.Linear(dim, dim * 2)
        linear2 = torch.nn.Linear(dim * 2, dim)
        linear1.weight.data.normal_(mean=0.0, std=0.01)
        linear1.bias.data.zero_()
        linear2.weight.data.normal_(mean=0.0, std=0.01)
        linear2.bias.data.zero_()
        linears = [linear1, linear2]

        self.linear = torch.nn.Sequential(*linears)
        self.multiplier = multiplier

    def forward(self, x):
        return x + self.linear(x) * self.multiplier


class Hypernetwork(torch.nn.Module):
    enable_sizes = [320, 640, 768, 1280]
    # return self.modules[Hypernetwork.enable_sizes.index(size)]

    def __init__(self, multiplier=1.0) -> None:
        super().__init__()
        self.modules = []
        for size in Hypernetwork.enable_sizes:
            self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
            self.register_module(f"{size}_0", self.modules[-1][0])
            self.register_module(f"{size}_1", self.modules[-1][1])

    def apply_to_stable_diffusion(self, text_encoder, vae, unet):
        blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
        for block in blocks:
            for subblk in block:
                if 'SpatialTransformer' in str(type(subblk)):
                    for tf_block in subblk.transformer_blocks:
                        for attn in [tf_block.attn1, tf_block.attn2]:
                            size = attn.context_dim
                            if size in Hypernetwork.enable_sizes:
                                attn.hypernetwork = self
                            else:
                                attn.hypernetwork = None

    def apply_to_diffusers(self, text_encoder, vae, unet):
        blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
        for block in blocks:
            if hasattr(block, 'attentions'):
                for subblk in block.attentions:
                    if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)):  # 0.6.0 and 0.7~
                        for tf_block in subblk.transformer_blocks:
                            for attn in [tf_block.attn1, tf_block.attn2]:
                                size = attn.to_k.in_features
                                if size in Hypernetwork.enable_sizes:
                                    attn.hypernetwork = self
                                else:
                                    attn.hypernetwork = None
        return True  # TODO error checking

    def forward(self, x, context):
        size = context.shape[-1]
        assert size in Hypernetwork.enable_sizes
        module = self.modules[Hypernetwork.enable_sizes.index(size)]
        return module[0].forward(context), module[1].forward(context)

    def load_from_state_dict(self, state_dict):
        # old ver to new ver
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }
        for key_from, key_to in changes.items():
            if key_from in state_dict:
                state_dict[key_to] = state_dict[key_from]
                del state_dict[key_from]

        for size, sd in state_dict.items():
            if type(size) == int:
                self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
                self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
        return True

    def get_state_dict(self):
        state_dict = {}
        for i, size in enumerate(Hypernetwork.enable_sizes):
            sd0 = self.modules[i][0].state_dict()
            sd1 = self.modules[i][1].state_dict()
            state_dict[size] = [sd0, sd1]
        return state_dict
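A minimal sketch of the forward pass (shapes are assumptions; x is unused, and the two modules produce separate contexts for attention keys and values):

    import torch
    from hypernetwork_nai import Hypernetwork

    hn = Hypernetwork()
    context = torch.randn(1, 77, 768)  # e.g. CLIP text embeddings for cross-attention
    context_k, context_v = hn.forward(None, context)  # residual MLPs selected by the last dimension (768)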
make_captions.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# This script is licensed under the Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss

import argparse
import glob
import os
import json

from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from models.blip import blip_decoder
# from Salesforce_BLIP.models.blip import blip_decoder

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(args):
    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
    print(f"found {len(image_paths)} images.")

    print(f"loading BLIP caption: {args.caption_weights}")
    image_size = 384
    model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large')
    model.eval()
    model = model.to(DEVICE)
    print("BLIP loaded")

    # a square resize feels questionable, but the BLIP source does it this way
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])

    # run captioning
    def run_batch(path_imgs):
        imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)

        with torch.no_grad():
            if args.beam_search:
                captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
                                          max_length=args.max_length, min_length=args.min_length)
            else:
                captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)

        for (image_path, _), caption in zip(path_imgs, captions):
            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
                f.write(caption + "\n")
                if args.debug:
                    print(image_path, caption)

    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        raw_image = Image.open(image_path)
        if raw_image.mode != "RGB":
            print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
            raw_image = raw_image.convert("RGB")

        image = transform(raw_image)
        b_imgs.append((image_path, image))
        if len(b_imgs) >= args.batch_size:
            run_batch(b_imgs)
            b_imgs.clear()
    if len(b_imgs) > 0:
        run_batch(b_imgs)

    print("done!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("caption_weights", type=str,
                        help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
    parser.add_argument("--caption_extention", type=str, default=None,
                        help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
    parser.add_argument("--beam_search", action="store_true",
                        help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search / beam search時のビーム数(多いと精度が上がるが時間がかかる)")
    parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
    parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
    parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
    parser.add_argument("--debug", action="store_true", help="debug mode")

    args = parser.parse_args()

    # restore the misspelled option for backward compatibility
    if args.caption_extention is not None:
        args.caption_extension = args.caption_extention

    main(args)
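A minimal command-line sketch (the directory name is hypothetical; model_large_caption.pth is the BLIP weight file named in the help text):

    python make_captions.py train_images model_large_caption.pth --beam_search --num_beams 3 --batch_size 8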
merge_captions_to_metadata.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# This script is licensed under the Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss

import argparse
import glob
import os
import json

from tqdm import tqdm


def main(args):
    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
    print(f"found {len(image_paths)} images.")

    if args.in_json is None and os.path.isfile(args.out_json):
        args.in_json = args.out_json

    if args.in_json is not None:
        print(f"loading existing metadata: {args.in_json}")
        with open(args.in_json, "rt", encoding='utf-8') as f:
            metadata = json.load(f)
        print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
    else:
        print("new metadata will be created / 新しいメタデータファイルが作成されます")
        metadata = {}

    print("merge caption texts to metadata json.")
    for image_path in tqdm(image_paths):
        caption_path = os.path.splitext(image_path)[0] + args.caption_extension
        with open(caption_path, "rt", encoding='utf-8') as f:
            caption = f.readlines()[0].strip()

        image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
        if image_key not in metadata:
            metadata[image_key] = {}

        metadata[image_key]['caption'] = caption
        if args.debug:
            print(image_key, caption)

    # write out the metadata and finish
    print(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
    parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
    parser.add_argument("--caption_extention", type=str, default=None,
                        help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
    parser.add_argument("--full_path", action="store_true",
                        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
    parser.add_argument("--debug", action="store_true", help="debug mode")

    args = parser.parse_args()

    # restore the misspelled option for backward compatibility
    if args.caption_extention is not None:
        args.caption_extension = args.caption_extention

    main(args)
merge_dd_tags_to_metadata.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# This script is licensed under the Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss

import argparse
import glob
import os
import json

from tqdm import tqdm


def main(args):
    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
    print(f"found {len(image_paths)} images.")

    if args.in_json is None and os.path.isfile(args.out_json):
        args.in_json = args.out_json

    if args.in_json is not None:
        print(f"loading existing metadata: {args.in_json}")
        with open(args.in_json, "rt", encoding='utf-8') as f:
            metadata = json.load(f)
        print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
    else:
        print("new metadata will be created / 新しいメタデータファイルが作成されます")
        metadata = {}

    print("merge tags to metadata json.")
    for image_path in tqdm(image_paths):
        tags_path = os.path.splitext(image_path)[0] + '.txt'
        with open(tags_path, "rt", encoding='utf-8') as f:
            tags = f.readlines()[0].strip()

        image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
        if image_key not in metadata:
            metadata[image_key] = {}

        metadata[image_key]['tags'] = tags
        if args.debug:
            print(image_key, tags)

    # write out the metadata and finish
    print(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
    parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
    parser.add_argument("--full_path", action="store_true",
                        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
    parser.add_argument("--debug", action="store_true", help="debug mode, print tags")

    args = parser.parse_args()
    main(args)
model_util.py (new file, 1182 lines; diff suppressed because it is too large)
prepare_buckets_latents.py (new file, 177 lines)
@@ -0,0 +1,177 @@
# This script is licensed under the Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss

import argparse
import glob
import os
import json

from tqdm import tqdm
import numpy as np
from diffusers import AutoencoderKL
from PIL import Image
import cv2
import torch
from torchvision import transforms

import model_util

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def get_latents(vae, images, weight_dtype):
    img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
    img_tensors = torch.stack(img_tensors)
    img_tensors = img_tensors.to(DEVICE, weight_dtype)
    with torch.no_grad():
        latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
    return latents


def main(args):
    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
    print(f"found {len(image_paths)} images.")

    if os.path.exists(args.in_json):
        print(f"loading existing metadata: {args.in_json}")
        with open(args.in_json, "rt", encoding='utf-8') as f:
            metadata = json.load(f)
    else:
        print(f"no metadata / メタデータファイルがありません: {args.in_json}")
        return

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
    vae.eval()
    vae.to(DEVICE, dtype=weight_dtype)

    # compute the bucket resolutions
    max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
    assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

    bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
        max_reso, args.min_bucket_reso, args.max_bucket_reso)

    # assign each image to the closest bucket by aspect ratio and compute its latents
    bucket_aspect_ratios = np.array(bucket_aspect_ratios)
    buckets_imgs = [[] for _ in range(len(bucket_resos))]
    bucket_counts = [0 for _ in range(len(bucket_resos))]
    img_ar_errors = []
    for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
        image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
        if image_key not in metadata:
            metadata[image_key] = {}

        image = Image.open(image_path)
        if image.mode != 'RGB':
            image = image.convert("RGB")

        aspect_ratio = image.width / image.height
        ar_errors = bucket_aspect_ratios - aspect_ratio
        bucket_id = np.abs(ar_errors).argmin()
        reso = bucket_resos[bucket_id]
        ar_error = ar_errors[bucket_id]
        img_ar_errors.append(abs(ar_error))

        # decide the resize target: scale so that the excess gets trimmed off
        if ar_error <= 0:  # the image is wider than the bucket -> match the height
            scale = reso[1] / image.height
        else:
            scale = reso[0] / image.width

        resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))

        # print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
        #       bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])

        assert resized_size[0] == reso[0] or resized_size[1] == reso[1], \
            f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
        assert resized_size[0] >= reso[0] and resized_size[1] >= reso[1], \
            f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"

        # resize and trim the image; PIL has no INTER_AREA, so use cv2
        image = np.array(image)
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
        if resized_size[0] > reso[0]:
            trim_size = resized_size[0] - reso[0]
            image = image[:, trim_size//2:trim_size//2 + reso[0]]
        elif resized_size[1] > reso[1]:
            trim_size = resized_size[1] - reso[1]
            image = image[trim_size//2:trim_size//2 + reso[1]]
        assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"

        # # debug
        # cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])

        # add to the bucket's batch
        buckets_imgs[bucket_id].append((image_key, reso, image))
        bucket_counts[bucket_id] += 1
        metadata[image_key]['train_resolution'] = reso

        # run inference when a batch is full (or on the last image)
        is_last = i == len(image_paths) - 1
        for j in range(len(buckets_imgs)):
            bucket = buckets_imgs[j]
            if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
                latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)

                for (image_key, reso, _), latent in zip(bucket, latents):
                    np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0]), latent)

                # flip augmentation
                if args.flip_aug:
                    latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype)  # without the copy, conversion to a Tensor fails

                    for (image_key, reso, _), latent in zip(bucket, latents):
                        np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0] + '_flip'), latent)

                bucket.clear()

    for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
        print(f"bucket {i} {reso}: {count}")
    img_ar_errors = np.array(img_ar_errors)
    print(f"mean ar error: {np.mean(img_ar_errors)}")

    # write out the metadata and finish
    print(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
    parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
    parser.add_argument("--v2", action='store_true',
                        help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument("--max_resolution", type=str, default="512,512",
                        help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
    parser.add_argument("--full_path", action="store_true",
                        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
    parser.add_argument("--flip_aug", action="store_true",
                        help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")

    args = parser.parse_args()
    main(args)
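A minimal command-line sketch of the full fine-tuning data preparation flow with the scripts in this commit (all file and directory names are hypothetical):

    python make_captions.py train_images model_large_caption.pth --batch_size 8
    python tag_images_by_wd14_tagger.py train_images --batch_size 4
    python merge_captions_to_metadata.py train_images meta_cap.json
    python merge_dd_tags_to_metadata.py train_images --in_json meta_cap.json meta_cap_dd.json
    python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
    python prepare_buckets_latents.py train_images meta_clean.json meta_lat.json model.ckpt --batch_size 4 --max_resolution 512,512 --mixed_precision fp16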
requirements_blip.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
timm==0.4.12
transformers==4.16.2
fairscale==0.4.4
requirements_db_finetune.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
accelerate==0.14.0
transformers>=4.21.0
ftfy
albumentations
opencv-python
einops
pytorch_lightning
safetensors
requirements_wd14_tagger.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
tensorflow<2.11
huggingface-hub
tag_images_by_wd14_tagger.py (new file, 143 lines)
@@ -0,0 +1,143 @@
# This script is licensed under the Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss

import argparse
import csv
import glob
import os

from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download

# from wd14 tagger
IMAGE_SIZE = 448

WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger'
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]


def main(args):
    # using hf_hub_download as-is reportedly causes symlink-related problems, so work
    # around it by specifying a cache directory and force_filename
    # a deprecation warning is printed, but deal with that when the option goes away
    # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
    if not os.path.exists(args.model_dir) or args.force_download:
        print("downloading wd14 tagger model from hf_hub")
        for file in FILES:
            hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
        for file in SUB_DIR_FILES:
            hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
                args.model_dir, SUB_DIR), force_download=True, force_filename=file)

    # load the images
    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
    print(f"found {len(image_paths)} images.")

    print("loading model and labels")
    model = load_model(args.model_dir)

    # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
    # read the CSV by hand to avoid adding pandas as a dependency
    with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        l = [row for row in reader]
        header = l[0]  # tag_id,name,category,count
        rows = l[1:]
    assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"

    tags = [row[1] for row in rows[1:] if row[2] == '0']  # category 0 means a normal tag

    # run inference
    def run_batch(path_imgs):
        imgs = np.array([im for _, im in path_imgs])

        probs = model(imgs, training=False)
        probs = probs.numpy()

        for (image_path, _), prob in zip(path_imgs, probs):
            # the first 4 entries are ratings, so skip them
            # # First 4 labels are actually ratings: pick one with argmax
            # ratings_names = label_names[:4]
            # rating_index = ratings_names["probs"].argmax()
            # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]

            # everything else is tags: pick any where prediction confidence > threshold
            tag_text = ""
            for i, p in enumerate(prob[4:]):  # numpy would be faster, but there aren't many labels, so a loop is fine
                if p >= args.thresh:
                    tag_text += ", " + tags[i]

            if len(tag_text) > 0:
                tag_text = tag_text[2:]  # remove the leading ", "

            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
                f.write(tag_text + '\n')
                if args.debug:
                    print(image_path, tag_text)

    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        img = Image.open(image_path)  # open with Pillow: cv2 fails on Japanese filenames, and we want mode conversion
        if img.mode != 'RGB':
            img = img.convert("RGB")
        img = np.array(img)
        img = img[:, :, ::-1]  # RGB -> BGR

        # pad to square
        size = max(img.shape[0:2])
        pad_x = size - img.shape[1]
        pad_y = size - img.shape[0]
        pad_l = pad_x // 2
        pad_t = pad_y // 2
        img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)

        interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
        img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
        # cv2.imshow("img", img)
        # cv2.waitKey()
        # cv2.destroyAllWindows()

        img = img.astype(np.float32)
        b_imgs.append((image_path, img))

        if len(b_imgs) >= args.batch_size:
            run_batch(b_imgs)
            b_imgs.clear()

    if len(b_imgs) > 0:
        run_batch(b_imgs)

    print("done!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO,
                        help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
    parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
                        help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
    parser.add_argument("--force_download", action='store_true',
                        help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
    parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument("--caption_extention", type=str, default=None,
                        help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
    parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
    parser.add_argument("--debug", action="store_true", help="debug mode")

    args = parser.parse_args()

    # restore the misspelled option for backward compatibility
    if args.caption_extention is not None:
        args.caption_extension = args.caption_extention

    main(args)
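A minimal command-line sketch (the directory name is hypothetical): raising --thresh above the 0.35 default keeps only higher-confidence tags, since a tag is written only when its predicted probability meets the threshold:

    python tag_images_by_wd14_tagger.py train_images --thresh 0.5 --batch_size 4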
train_db_fixed.py (new file, 1228 lines; diff suppressed because it is too large)