mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add scripts.
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
logs
|
||||||
|
__pycache__
|
||||||
|
wd14_tagger_model
|
||||||
12
README-ja.md
12
README-ja.md
@@ -3,9 +3,21 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ
|
|||||||
|
|
||||||
[README in English](./README.md)
|
[README in English](./README.md)
|
||||||
|
|
||||||
|
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
|
||||||
|
|
||||||
以下のスクリプトがあります。
|
以下のスクリプトがあります。
|
||||||
|
|
||||||
* DreamBooth、U-NetおよびText Encoderの学習をサポート
|
* DreamBooth、U-NetおよびText Encoderの学習をサポート
|
||||||
* fine-tuning、同上
|
* fine-tuning、同上
|
||||||
* 画像生成
|
* 画像生成
|
||||||
* モデル変換(Stable Diffusion ckpt/safetensorsとDiffusersの相互変換)
|
* モデル変換(Stable Diffusion ckpt/safetensorsとDiffusersの相互変換)
|
||||||
|
|
||||||
|
## 使用法について
|
||||||
|
|
||||||
|
note.comに記事がありますのでそちらをご覧ください(将来的にはこちらへ移すかもしれません)。
|
||||||
|
|
||||||
|
* [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nee3ed1649fb6)
|
||||||
|
* [fine-tuningスクリプト](https://note.com/kohya_ss/n/nbf7ce8d80f29):
|
||||||
|
Including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
|
||||||
|
* [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
|
||||||
|
* [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||||
|
|||||||
20
README.md
20
README.md
@@ -2,11 +2,27 @@ This repository contains training, generation and utility scripts for Stable Dif
|
|||||||
|
|
||||||
[日本語版README](./README-ja.md)
|
[日本語版README](./README-ja.md)
|
||||||
|
|
||||||
This repository currently contains:
|
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
|
||||||
|
|
||||||
|
This repository contains the scripts for:
|
||||||
|
|
||||||
* DreamBooth training, including U-Net and Text Encoder
|
* DreamBooth training, including U-Net and Text Encoder
|
||||||
* fine-tuning (native training), including U-Net and Text Encoder
|
* fine-tuning (native training), including U-Net and Text Encoder
|
||||||
* image generation
|
* image generation
|
||||||
* model conversion (Stable Diffusion ckpt/safetensors and Diffusers)
|
* model conversion (supports 1.x and 2.x, Stable Diffusion ckpt/safetensors and Diffusers)
|
||||||
|
|
||||||
|
## About requirements_*.txt
|
||||||
|
|
||||||
|
These files do not contain requirements for PyTorch and Diffusers. Because the versions of them depend on your environment. Please install PyTorch at first, then Diffusers.
|
||||||
|
|
||||||
|
The scripts is tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2.
|
||||||
|
|
||||||
|
## Links to how-to-use documents
|
||||||
|
|
||||||
|
All documents are in Japanese currently, and CUI based.
|
||||||
|
|
||||||
|
* [Environment setup and DreamBooth training guide](https://note.com/kohya_ss/n/nee3ed1649fb6)
|
||||||
|
* [Fine-tuning step-by-step guide](https://note.com/kohya_ss/n/nbf7ce8d80f29):
|
||||||
|
Including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
|
||||||
|
* [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
|
||||||
|
* [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||||
|
|||||||
123
clean_captions_and_tags.py
Normal file
123
clean_captions_and_tags.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def clean_tags(image_key, tags):
    """Normalize a tag string.

    Underscores become spaces, and any trailing "rating" token (emitted by
    DeepDanbooru) is stripped. WD14 tagger output has no rating token, so it
    passes through unchanged apart from the underscore replacement.
    """
    tags = tags.replace('_', ' ')

    # DeepDanbooru appends a rating tag; drop everything from ", rating" on.
    pieces = tags.split(", rating")
    if len(pieces) > 1:
        if len(pieces) > 2:
            # more than one rating token is unexpected; report it
            print("multiple ratings:")
            print(f"{image_key} {tags}")
        tags = pieces[0]

    return tags
|
||||||
|
|
||||||
|
|
||||||
|
# Search-and-replace pairs applied to BLIP captions, in list order.
# ('search string', 'replacement string')
CAPTION_REPLACEMENTS = [
    ('anime anime', 'anime'),
    ('young ', ''),
    ('anime girl', 'girl'),
    ('cartoon female', 'girl'),
    ('cartoon lady', 'girl'),
    ('cartoon character', 'girl'),  # a or ~s
    ('cartoon woman', 'girl'),
    ('cartoon women', 'girls'),
    ('cartoon girl', 'girl'),
    ('anime female', 'girl'),
    ('anime lady', 'girl'),
    ('anime character', 'girl'),  # a or ~s
    ('anime woman', 'girl'),
    ('anime women', 'girls'),
    ('lady', 'girl'),
    ('female', 'girl'),
    ('woman', 'girl'),
    ('women', 'girls'),
    ('people', 'girls'),
    ('person', 'girl'),
    ('a cartoon figure', 'a figure'),
    ('a cartoon image', 'an image'),
    ('a cartoon picture', 'a picture'),
    ('an anime cartoon image', 'an image'),
    ('a cartoon anime drawing', 'a drawing'),
    ('a cartoon drawing', 'a drawing'),
    ('girl girl', 'girl'),
]


def clean_caption(caption):
    """Apply CAPTION_REPLACEMENTS to *caption*.

    Each substitution is repeated until it no longer changes the text, so
    overlapping repeats such as 'anime anime anime' collapse fully.
    """
    for search, repl in CAPTION_REPLACEMENTS:
        while True:
            updated = caption.replace(search, repl)
            if updated == caption:
                break
            caption = updated
    return caption
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Load the metadata JSON, clean every caption and tag string in it, and
    write the result to args.out_json.

    Entries missing 'tags' or 'caption' are reported and left untouched.
    """
    if not os.path.exists(args.in_json):
        print("no metadata / メタデータファイルがありません")
        return

    print(f"loading existing metadata: {args.in_json}")
    with open(args.in_json, "rt", encoding='utf-8') as f:
        metadata = json.load(f)

    print("cleaning captions and tags.")
    for image_key in tqdm(list(metadata.keys())):
        entry = metadata[image_key]

        tags = entry.get('tags')
        if tags is None:
            print(f"image does not have tags / メタデータにタグがありません: {image_key}")
        else:
            entry['tags'] = clean_tags(image_key, tags)

        caption = entry.get('caption')
        if caption is None:
            print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
        else:
            entry['caption'] = clean_caption(caption)

    # write the updated metadata and finish
    print(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NOTE: the old positional train_data_dir argument was removed.
    parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")

    args, unknown = parser.parse_known_args()
    # Backward compatibility: the script used to take three positionals
    # (train_data_dir, in_json, out_json). If exactly one extra argument is
    # present, shift the values so old command lines keep working, and warn.
    if len(unknown) == 1:
        print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
        print("All captions and tags in the metadata are processed.")
        print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
        print("メタデータ内のすべてのキャプションとタグが処理されます。")
        args.in_json = args.out_json
        args.out_json = unknown[0]
    elif unknown:
        raise ValueError(f"error: unrecognized arguments: {unknown}")

    main(args)
|
||||||
93
convert_diffusers20_original_sd.py
Normal file
93
convert_diffusers20_original_sd.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
|
||||||
|
# v1: initial version
|
||||||
|
# v2: support safetensors
|
||||||
|
# v3: fix to support another format
|
||||||
|
# v4: support safetensors in Diffusers
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
|
||||||
|
import model_util
|
||||||
|
|
||||||
|
|
||||||
|
def convert(args):
    """Convert between Diffusers and original Stable Diffusion formats.

    Direction is inferred from the arguments: model_to_load being a file
    means "load checkpoint"; model_to_save having an extension means "save
    checkpoint" (otherwise a Diffusers directory is written).
    """
    # dtypes: fp16 affects both loading (Diffusers only) and saving
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16:
        save_dtype = torch.float16
    elif args.bf16:
        save_dtype = torch.bfloat16
    elif args.float:
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
    assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"

    # load the model
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
    print(f"loading {msg}: {args.model_to_load}")

    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # neither (or both) flags given: auto-detect from the UNet config
            v2_model = unet.config.cross_attention_dim == 1024
            print("checking model version: model is " + ('v2' if v2_model else 'v1'))
        else:
            # BUGFIX: this was `v2_model = args.v1`, which inverted the
            # user's explicit choice (--v1 produced a v2 conversion and
            # vice versa). Honor the --v2 flag directly.
            v2_model = args.v2

    # convert and save
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
    print(f"converting and saving as {msg}: {args.model_to_save}")

    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
                                                                original_model, args.epoch, args.global_step, save_dtype, vae)
        print(f"model saved. total converted state_dict keys: {key_count}")
    else:
        print(f"copy scheduler/tokenizer config from: {args.reference_model}")
        model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
        print(f"model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model version flags (exactly one is required when loading a checkpoint)
    parser.add_argument("--v1", action='store_true',
                        help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
    parser.add_argument("--v2", action='store_true',
                        help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
    # precision options
    parser.add_argument("--fp16", action='store_true',
                        help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
    parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
    parser.add_argument("--float", action='store_true',
                        help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
    # metadata written into a saved checkpoint
    parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
    parser.add_argument("--global_step", type=int, default=0,
                        help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
    # Diffusers-output options
    parser.add_argument("--reference_model", type=str, default=None,
                        help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
    parser.add_argument("--use_safetensors", action='store_true',
                        help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")
    # positional source / destination
    parser.add_argument("model_to_load", type=str, default=None,
                        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
    parser.add_argument("model_to_save", type=str, default=None,
                        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")

    args = parser.parse_args()
    convert(args)
|
||||||
239
detect_face_rotate.py
Normal file
239
detect_face_rotate.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
# 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す
|
||||||
|
|
||||||
|
# v2: extract max face if multiple faces are found
|
||||||
|
# v3: add crop_ratio option
|
||||||
|
# v4: add multple faces extraction and min/max size
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import cv2
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
from anime_face_detector import create_detector
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# anime_face_detector keypoint indices for the right and left eye
KP_REYE = 11
KP_LEYE = 19

# detections with a lower confidence than this are discarded
SCORE_THRES = 0.90


def detect_faces(detector, image, min_size):
    """Run the detector on a BGR image and return face candidates.

    Each candidate is (center_x, center_y, width, height, roll_angle_deg),
    sorted largest face first. Low-confidence detections are dropped.

    NOTE(review): min_size is accepted but not used here — size filtering
    happens in the caller; confirm before relying on it.
    """
    candidates = []
    for pred in detector(image):  # detector takes a BGR image
        bbox = pred['bbox']
        if bbox[-1] < SCORE_THRES:  # last element is the confidence score
            continue

        left, top, right, bottom = bbox[:4]
        center_x = int((left + right) / 2)
        center_y = int((top + bottom) / 2)
        width = int(right - left)
        height = int(bottom - top)

        # roll angle (degrees) from the line through both eyes
        lex, ley = pred['keypoints'][KP_LEYE, 0:2]
        rex, rey = pred['keypoints'][KP_REYE, 0:2]
        roll = math.atan2(ley - rey, lex - rex) / math.pi * 180

        candidates.append((center_x, center_y, width, height, roll))

    # largest face first
    candidates.sort(key=lambda face: max(face[2], face[3]), reverse=True)
    return candidates
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_image(image, angle, cx, cy):
    """Rotate *image* by *angle* degrees around the point (cx, cy).

    The canvas keeps its original size; revealed edges are filled by
    reflection. Returns (rotated_image, cx, cy) — the center is unchanged
    here, but the tuple shape lets callers receive an adjusted center.

    An earlier revision enlarged the canvas to avoid clipping the rotated
    corners; that path was disabled, so output size equals input size.
    """
    height, width = image.shape[0:2]
    rotation = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)

    rotated = cv2.warpAffine(image, rotation, (width, height),
                             flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    return rotated, cx, cy
|
||||||
|
|
||||||
|
|
||||||
|
def process(args):
    """Detect anime faces in every image under args.src_dir, optionally rotate
    them upright and crop around each face, then save the results as PNG to
    args.dst_dir.

    The output filename encodes the (crop-relative) face center and size:
    {basename}{face_suffix}_{cx}_{cy}_{fw}_{fh}.png
    """
    assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
    assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"

    # load the anime face detection model
    print("loading face detector.")
    detector = create_detector('yolov3')

    # parse crop_size ("width,height" in pixels)
    if args.crop_size is None:
        crop_width = crop_height = None
    else:
        tokens = args.crop_size.split(',')
        assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
        crop_width, crop_height = [int(t) for t in tokens]

    # parse crop_ratio ("horizontal,vertical" multipliers of the face size)
    if args.crop_ratio is None:
        crop_h_ratio = crop_v_ratio = None
    else:
        tokens = args.crop_ratio.split(',')
        assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
        crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]

    # process the images
    print("processing.")
    output_extension = ".png"

    os.makedirs(args.dst_dir, exist_ok=True)
    paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.src_dir, "*.webp"))
    for path in tqdm(paths):
        basename = os.path.splitext(os.path.basename(path))[0]

        # cv2.imread fails on non-ASCII file names, so decode from a raw buffer
        image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
        if len(image.shape) == 2:
            # grayscale -> BGR
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.shape[2] == 4:
            print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
            # copy() so the alpha plane is really dropped, not just viewed away
            image = image[:, :, :3].copy()

        h, w = image.shape[:2]

        # NOTE(review): the third parameter of detect_faces is named min_size
        # but receives args.multiple_faces here; it is unused in detect_faces.
        faces = detect_faces(detector, image, args.multiple_faces)
        for i, face in enumerate(faces):
            cx, cy, fw, fh, angle = face
            face_size = max(fw, fh)
            # size filtering: min is inclusive, max is exclusive
            if args.min_size is not None and face_size < args.min_size:
                continue
            if args.max_size is not None and face_size >= args.max_size:
                continue
            face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""

            # rotate the face upright if requested
            face_img = image
            if args.rotate:
                face_img, cx, cy = rotate_image(face_img, angle, cx, cy)

            # crop around the face if requested (fixed pixels or face-size ratio)
            if crop_width is not None or crop_h_ratio is not None:
                cur_crop_width, cur_crop_height = crop_width, crop_height
                if crop_h_ratio is not None:
                    cur_crop_width = int(face_size * crop_h_ratio + .5)
                    cur_crop_height = int(face_size * crop_v_ratio + .5)

                # resize first if needed
                scale = 1.0
                if args.resize_face_size is not None:
                    # scale so the face reaches the requested size, but never
                    # below what keeps the crop inside the image
                    scale = args.resize_face_size / face_size
                    if scale < cur_crop_width / w:
                        print(
                            f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
                        scale = cur_crop_width / w
                    if scale < cur_crop_height / h:
                        print(
                            f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
                        scale = cur_crop_height / h
                elif crop_h_ratio is not None:
                    # ratio-based crop: no resizing
                    pass
                else:
                    # fixed-size crop: upscale when the image is smaller than the crop
                    if w < cur_crop_width:
                        print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
                        scale = cur_crop_width / w
                    if h < cur_crop_height:
                        print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
                        scale = cur_crop_height / h
                    if args.resize_fit:
                        scale = max(cur_crop_width / w, cur_crop_height / h)

                if scale != 1.0:
                    w = int(w * scale + .5)
                    h = int(h * scale + .5)
                    face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
                    cx = int(cx * scale + .5)
                    cy = int(cy * scale + .5)
                    fw = int(fw * scale + .5)
                    fh = int(fh * scale + .5)

                cur_crop_width = min(cur_crop_width, face_img.shape[1])
                cur_crop_height = min(cur_crop_height, face_img.shape[0])

                # horizontal crop clamped to the image; cx becomes crop-relative
                x = cx - cur_crop_width // 2
                cx = cur_crop_width // 2
                if x < 0:
                    cx = cx + x
                    x = 0
                elif x + cur_crop_width > w:
                    cx = cx + (x + cur_crop_width - w)
                    x = w - cur_crop_width
                face_img = face_img[:, x:x+cur_crop_width]

                # vertical crop, same scheme
                y = cy - cur_crop_height // 2
                cy = cur_crop_height // 2
                if y < 0:
                    cy = cy + y
                    y = 0
                elif y + cur_crop_height > h:
                    cy = cy + (y + cur_crop_height - h)
                    y = h - cur_crop_height
                face_img = face_img[y:y + cur_crop_height]

            # debug: draw the detected face rectangle on the output image
            if args.debug:
                cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)

            # encode via imencode + tofile to support non-ASCII output paths
            _, buf = cv2.imencode(output_extension, face_img)
            with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
                buf.tofile(f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # command line interface; crop_size and crop_ratio are mutually shaped
    # as "a,b" strings and parsed inside process()
    parser = argparse.ArgumentParser()
    parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
    parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
    parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
    parser.add_argument("--resize_fit", action="store_true",
                        help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
    parser.add_argument("--resize_face_size", type=int, default=None,
                        help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
    parser.add_argument("--crop_size", type=str, default=None,
                        help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
    parser.add_argument("--crop_ratio", type=str, default=None,
                        help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
    parser.add_argument("--min_size", type=int, default=None,
                        help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
    parser.add_argument("--max_size", type=int, default=None,
                        help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
    parser.add_argument("--multiple_faces", action="store_true",
                        help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
    parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
    args = parser.parse_args()

    process(args)
|
||||||
1059
fine_tune.py
Normal file
1059
fine_tune.py
Normal file
File diff suppressed because it is too large
Load Diff
2517
gen_img_diffusers.py
Normal file
2517
gen_img_diffusers.py
Normal file
File diff suppressed because it is too large
Load Diff
96
hypernetwork_nai.py
Normal file
96
hypernetwork_nai.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# NAI compatible
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class HypernetworkModule(torch.nn.Module):
    """One NAI-style hypernetwork branch: a two-layer MLP with a residual
    connection.

    forward(x) returns x + MLP(x) * multiplier; weights start ~N(0, 0.01)
    with zero biases, so the module is initially close to the identity.
    """

    def __init__(self, dim, multiplier=1.0):
        super().__init__()

        first = torch.nn.Linear(dim, dim * 2)
        second = torch.nn.Linear(dim * 2, dim)
        for layer in (first, second):
            layer.weight.data.normal_(mean=0.0, std=0.01)
            layer.bias.data.zero_()

        self.linear = torch.nn.Sequential(first, second)
        self.multiplier = multiplier

    def forward(self, x):
        # residual connection scaled by the strength multiplier
        return x + self.linear(x) * self.multiplier
|
||||||
|
|
||||||
|
|
||||||
|
class Hypernetwork(torch.nn.Module):
    """NAI-compatible hypernetwork: one (key, value) pair of
    HypernetworkModules per supported cross-attention context size."""

    # cross-attention context channel sizes this hypernetwork can hook
    enable_sizes = [320, 640, 768, 1280]
    # return self.modules[Hypernetwork.enable_sizes.index(size)]

    def __init__(self, multiplier=1.0) -> None:
        super().__init__()
        # NOTE: this attribute shadows nn.Module.modules(); the pairs are
        # registered explicitly below so their parameters are still tracked.
        self.modules = []
        for size in Hypernetwork.enable_sizes:
            self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
            self.register_module(f"{size}_0", self.modules[-1][0])
            self.register_module(f"{size}_1", self.modules[-1][1])

    def apply_to_stable_diffusion(self, text_encoder, vae, unet):
        """Attach this hypernetwork to the cross-attention layers of an
        original (CompVis/LDM-style) Stable Diffusion U-Net; layers with an
        unsupported context size get hypernetwork=None."""
        blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
        for block in blocks:
            for subblk in block:
                if 'SpatialTransformer' in str(type(subblk)):
                    for tf_block in subblk.transformer_blocks:
                        for attn in [tf_block.attn1, tf_block.attn2]:
                            size = attn.context_dim
                            if size in Hypernetwork.enable_sizes:
                                attn.hypernetwork = self
                            else:
                                attn.hypernetwork = None

    def apply_to_diffusers(self, text_encoder, vae, unet):
        """Same as apply_to_stable_diffusion but for a Diffusers U-Net; the
        class-name check covers both Diffusers 0.6.0 and 0.7+."""
        blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
        for block in blocks:
            if hasattr(block, 'attentions'):
                for subblk in block.attentions:
                    if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)):  # 0.6.0 and 0.7~
                        for tf_block in subblk.transformer_blocks:
                            for attn in [tf_block.attn1, tf_block.attn2]:
                                size = attn.to_k.in_features
                                if size in Hypernetwork.enable_sizes:
                                    attn.hypernetwork = self
                                else:
                                    attn.hypernetwork = None
        return True  # TODO error checking

    def forward(self, x, context):
        """Return the (key, value) transforms of *context* from the module
        pair matching the context's last-dimension size."""
        size = context.shape[-1]
        assert size in Hypernetwork.enable_sizes
        module = self.modules[Hypernetwork.enable_sizes.index(size)]
        return module[0].forward(context), module[1].forward(context)

    def load_from_state_dict(self, state_dict):
        """Load weights from an NAI-format hypernetwork state dict, renaming
        old-version keys in place first."""
        # old ver to new ver
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }
        for key_from, key_to in changes.items():
            if key_from in state_dict:
                state_dict[key_to] = state_dict[key_from]
                del state_dict[key_from]

        # integer keys hold the per-size (key, value) module pair state dicts
        for size, sd in state_dict.items():
            if type(size) == int:
                self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
                self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
        return True

    def get_state_dict(self):
        """Return a {size: [key_module_sd, value_module_sd]} dict in the
        NAI-compatible layout (inverse of load_from_state_dict)."""
        state_dict = {}
        for i, size in enumerate(Hypernetwork.enable_sizes):
            sd0 = self.modules[i][0].state_dict()
            sd1 = self.modules[i][1].state_dict()
            state_dict[size] = [sd0, sd1]
        return state_dict
|
||||||
98
make_captions.py
Normal file
98
make_captions.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
from models.blip import blip_decoder
|
||||||
|
# from Salesforce_BLIP.models.blip import blip_decoder
|
||||||
|
|
||||||
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Caption every image in args.train_data_dir with BLIP and write one
    caption file (args.caption_extension) next to each image."""
    image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
        glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
    print(f"found {len(image_paths)} images.")

    print(f"loading BLIP caption: {args.caption_weights}")
    image_size = 384  # BLIP ViT-L input resolution
    model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large')
    model.eval()
    model = model.to(DEVICE)
    print("BLIP loaded")

    # square resize matches the original BLIP preprocessing
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])

    # run captioning for one batch of (path, tensor) pairs and write results
    def run_batch(path_imgs):
        imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)

        with torch.no_grad():
            if args.beam_search:
                captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
                                          max_length=args.max_length, min_length=args.min_length)
            else:
                # Nucleus sampling (default)
                captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)

        for (image_path, _), caption in zip(path_imgs, captions):
            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
                f.write(caption + "\n")
                if args.debug:
                    print(image_path, caption)

    # accumulate images into batches and flush them through run_batch
    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        raw_image = Image.open(image_path)
        if raw_image.mode != "RGB":
            print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
            raw_image = raw_image.convert("RGB")

        image = transform(raw_image)
        b_imgs.append((image_path, image))
        if len(b_imgs) >= args.batch_size:
            run_batch(b_imgs)
            b_imgs.clear()
    # flush the final partial batch
    if len(b_imgs) > 0:
        run_batch(b_imgs)

    print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("caption_weights", type=str,
                        help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
    # "--caption_extention" is a historical misspelling kept for backward compatibility
    parser.add_argument("--caption_extention", type=str, default=None,
                        help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
    parser.add_argument("--beam_search", action="store_true",
                        help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
    parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
    parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
    parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
    parser.add_argument("--debug", action="store_true", help="debug mode")

    args = parser.parse_args()

    # the misspelled option wins when given, so old command lines keep working
    if args.caption_extention is not None:
        args.caption_extension = args.caption_extention

    main(args)
|
||||||
68
merge_captions_to_metadata.py
Normal file
68
merge_captions_to_metadata.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
  """Merge per-image caption text files into one metadata JSON file.

  For every .jpg/.png/.webp image in args.train_data_dir, reads the sibling
  caption file (same basename + args.caption_extension) and stores its first
  line under the image's entry, then writes the metadata to args.out_json.
  """
  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
  print(f"found {len(image_paths)} images.")

  # If no input metadata was given but the output file already exists,
  # update the existing output file in place.
  if args.in_json is None and os.path.isfile(args.out_json):
    args.in_json = args.out_json

  if args.in_json is not None:
    print(f"loading existing metadata: {args.in_json}")
    with open(args.in_json, "rt", encoding='utf-8') as f:
      metadata = json.load(f)
    print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
  else:
    print("new metadata will be created / 新しいメタデータファイルが作成されます")
    metadata = {}

  print("merge caption texts to metadata json.")
  for image_path in tqdm(image_paths):
    caption_path = os.path.splitext(image_path)[0] + args.caption_extension
    with open(caption_path, "rt", encoding='utf-8') as f:
      # fix: readline() instead of readlines()[0] — an empty caption file
      # previously raised IndexError; now it yields an empty caption.
      caption = f.readline().strip()

    # Key is the full path (safe across directories) or just the basename.
    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
    if image_key not in metadata:
      metadata[image_key] = {}

    metadata[image_key]['caption'] = caption
    if args.debug:
      print(image_key, caption)

  # Write out the metadata and finish.
  print(f"writing metadata: {args.out_json}")
  with open(args.out_json, "wt", encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)
  print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
  # Command-line interface for merging caption files into metadata JSON.
  cli = argparse.ArgumentParser()
  cli.add_argument("train_data_dir", type=str,
                   help="directory for train images / 学習画像データのディレクトリ")
  cli.add_argument("out_json", type=str,
                   help="metadata file to output / メタデータファイル書き出し先")
  cli.add_argument("--in_json", type=str,
                   help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
  cli.add_argument("--caption_extention", type=str, default=None,
                   help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
  cli.add_argument("--caption_extension", type=str, default=".caption",
                   help="extension of caption file / 読み込むキャプションファイルの拡張子")
  cli.add_argument("--full_path", action="store_true",
                   help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
  cli.add_argument("--debug", action="store_true", help="debug mode")

  args = cli.parse_args()

  # Restore the historically misspelled option name for backward compatibility.
  if args.caption_extention is not None:
    args.caption_extension = args.caption_extention

  main(args)
|
||||||
60
merge_dd_tags_to_metadata.py
Normal file
60
merge_dd_tags_to_metadata.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
  """Merge per-image tag files (.txt) into one metadata JSON file.

  For every .jpg/.png/.webp image in args.train_data_dir, reads the sibling
  .txt tag file and stores its first line under the image's 'tags' entry,
  then writes the metadata to args.out_json.
  """
  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
  print(f"found {len(image_paths)} images.")

  # If no input metadata was given but the output file already exists,
  # update the existing output file in place.
  if args.in_json is None and os.path.isfile(args.out_json):
    args.in_json = args.out_json

  if args.in_json is not None:
    print(f"loading existing metadata: {args.in_json}")
    with open(args.in_json, "rt", encoding='utf-8') as f:
      metadata = json.load(f)
    print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
  else:
    print("new metadata will be created / 新しいメタデータファイルが作成されます")
    metadata = {}

  print("merge tags to metadata json.")
  for image_path in tqdm(image_paths):
    tags_path = os.path.splitext(image_path)[0] + '.txt'
    with open(tags_path, "rt", encoding='utf-8') as f:
      # fix: readline() instead of readlines()[0] — an empty tag file
      # previously raised IndexError; now it yields an empty tag string.
      tags = f.readline().strip()

    # Key is the full path (safe across directories) or just the basename.
    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
    if image_key not in metadata:
      metadata[image_key] = {}

    metadata[image_key]['tags'] = tags
    if args.debug:
      print(image_key, tags)

  # Write out the metadata and finish.
  print(f"writing metadata: {args.out_json}")
  with open(args.out_json, "wt", encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)
  print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
  # Command-line interface for merging tag files into metadata JSON.
  cli = argparse.ArgumentParser()
  cli.add_argument("train_data_dir", type=str,
                   help="directory for train images / 学習画像データのディレクトリ")
  cli.add_argument("out_json", type=str,
                   help="metadata file to output / メタデータファイル書き出し先")
  cli.add_argument("--in_json", type=str,
                   help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
  cli.add_argument("--full_path", action="store_true",
                   help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
  cli.add_argument("--debug", action="store_true", help="debug mode, print tags")

  main(cli.parse_args())
|
||||||
1182
model_util.py
Normal file
1182
model_util.py
Normal file
File diff suppressed because it is too large
Load Diff
177
prepare_buckets_latents.py
Normal file
177
prepare_buckets_latents.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import model_util
|
||||||
|
|
||||||
|
# Device used for VAE inference: CUDA when available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# PIL image -> float tensor normalized to [-1, 1] (the range the VAE expects).
IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_latents(vae, images, weight_dtype):
  """Encode a batch of images with the VAE and return sampled latents.

  Images are converted via IMAGE_TRANSFORMS, stacked into one batch on
  DEVICE with the given dtype; the result is a float32 numpy array on CPU.
  """
  batch = torch.stack([IMAGE_TRANSFORMS(im) for im in images])
  batch = batch.to(DEVICE, weight_dtype)
  with torch.no_grad():
    sampled = vae.encode(batch).latent_dist.sample()
  return sampled.float().to("cpu").numpy()
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
  """Assign each training image to an aspect-ratio bucket, encode it with the
  VAE, and save the latents as .npz files alongside the training images.

  Reads metadata from args.in_json, records each image's chosen training
  resolution under 'train_resolution', and writes the updated metadata to
  args.out_json. Latents are saved into args.train_data_dir as basename.npz
  (plus basename_flip.npz when --flip_aug is given).
  """
  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
  print(f"found {len(image_paths)} images.")

  if os.path.exists(args.in_json):
    print(f"loading existing metadata: {args.in_json}")
    with open(args.in_json, "rt", encoding='utf-8') as f:
      metadata = json.load(f)
  else:
    print(f"no metadata / メタデータファイルがありません: {args.in_json}")
    return

  # Select the dtype for VAE weights/inputs from --mixed_precision.
  weight_dtype = torch.float32
  if args.mixed_precision == "fp16":
    weight_dtype = torch.float16
  elif args.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

  vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
  vae.eval()
  vae.to(DEVICE, dtype=weight_dtype)

  # Compute the set of bucket resolutions from the maximum resolution.
  max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
  assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

  bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
      max_reso, args.min_bucket_reso, args.max_bucket_reso)

  # Assign each image to the bucket with the closest aspect ratio and compute
  # latents bucket by bucket (a batch must share a single resolution).
  bucket_aspect_ratios = np.array(bucket_aspect_ratios)
  buckets_imgs = [[] for _ in range(len(bucket_resos))]  # pending (key, reso, image) per bucket
  bucket_counts = [0 for _ in range(len(bucket_resos))]  # total images assigned per bucket
  img_ar_errors = []  # |aspect-ratio error| per image, reported at the end
  for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
    if image_key not in metadata:
      metadata[image_key] = {}

    image = Image.open(image_path)
    if image.mode != 'RGB':
      image = image.convert("RGB")

    aspect_ratio = image.width / image.height
    ar_errors = bucket_aspect_ratios - aspect_ratio
    bucket_id = np.abs(ar_errors).argmin()
    reso = bucket_resos[bucket_id]
    ar_error = ar_errors[bucket_id]
    img_ar_errors.append(abs(ar_error))

    # Choose the resize scale: fit one dimension exactly, crop the other.
    if ar_error <= 0:  # image is wider than the bucket -> match the height
      scale = reso[1] / image.height
    else:  # image is taller than the bucket -> match the width
      scale = reso[0] / image.width

    resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))

    # print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
    #       bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])

    assert resized_size[0] == reso[0] or resized_size[1] == reso[
        1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
    assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
        1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"

    # Resize, then center-crop to the bucket resolution.
    # PIL has no INTER_AREA equivalent, so cv2 is used for downscaling.
    image = np.array(image)
    image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
    if resized_size[0] > reso[0]:  # crop excess width
      trim_size = resized_size[0] - reso[0]
      image = image[:, trim_size//2:trim_size//2 + reso[0]]
    elif resized_size[1] > reso[1]:  # crop excess height
      trim_size = resized_size[1] - reso[1]
      image = image[trim_size//2:trim_size//2 + reso[1]]
    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"

    # # debug
    # cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])

    # Queue the image in its bucket's pending batch.
    buckets_imgs[bucket_id].append((image_key, reso, image))
    bucket_counts[bucket_id] += 1
    metadata[image_key]['train_resolution'] = reso

    # Flush any bucket whose batch is full (or everything on the last image).
    is_last = i == len(image_paths) - 1
    for j in range(len(buckets_imgs)):
      bucket = buckets_imgs[j]
      if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
        latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)

        for (image_key, reso, _), latent in zip(bucket, latents):
          np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0]), latent)

        # flip augmentation: also save latents for the mirrored images
        if args.flip_aug:
          latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype)  # .copy() is required for the Tensor conversion

          for (image_key, reso, _), latent in zip(bucket, latents):
            np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0] + '_flip'), latent)

        bucket.clear()

  for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
    print(f"bucket {i} {reso}: {count}")
  img_ar_errors = np.array(img_ar_errors)
  print(f"mean ar error: {np.mean(img_ar_errors)}")

  # Write the updated metadata and finish.
  print(f"writing metadata: {args.out_json}")
  with open(args.out_json, "wt", encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)
  print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
  # Command-line interface for bucketed latent preparation.
  parser = argparse.ArgumentParser()
  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
  parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
  parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
  parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
  parser.add_argument("--v2", action='store_true',
                      help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
  parser.add_argument("--max_resolution", type=str, default="512,512",
                      help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
  parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
  # fix: help text previously said 最小解像度 (minimum); this option is the *maximum* bucket resolution
  parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
  parser.add_argument("--mixed_precision", type=str, default="no",
                      choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
  parser.add_argument("--full_path", action="store_true",
                      help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
  parser.add_argument("--flip_aug", action="store_true",
                      help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")

  args = parser.parse_args()
  main(args)
|
||||||
3
requirements_blip.txt
Normal file
3
requirements_blip.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
timm==0.4.12
|
||||||
|
transformers==4.16.2
|
||||||
|
fairscale==0.4.4
|
||||||
8
requirements_db_finetune.txt
Normal file
8
requirements_db_finetune.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
accelerate==0.14.0
|
||||||
|
transformers>=4.21.0
|
||||||
|
ftfy
|
||||||
|
albumentations
|
||||||
|
opencv-python
|
||||||
|
einops
|
||||||
|
pytorch_lightning
|
||||||
|
safetensors
|
||||||
2
requirements_wd14_tagger.txt
Normal file
2
requirements_wd14_tagger.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
tensorflow<2.11
|
||||||
|
huggingface-hub
|
||||||
143
tag_images_by_wd14_tagger.py
Normal file
143
tag_images_by_wd14_tagger.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
# from wd14 tagger
# Square input resolution (pixels) fed to the tagger network.
IMAGE_SIZE = 448

# Default Hugging Face repository hosting the wd14 tagger model.
WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger'
# Files downloaded from the repo root; the last entry is the tag-list CSV.
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
# TensorFlow SavedModel variables live in this subfolder of the repo.
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
  """Tag every image in args.train_data_dir with the wd14 tagger model.

  Downloads the model from Hugging Face if needed, runs batched inference,
  and writes one caption file per image (basename + args.caption_extension)
  containing the comma-separated tags whose confidence >= args.thresh.
  """
  # Using hf_hub_download directly reportedly has symlink-related problems, so
  # work around it by specifying the cache directory and force_filename.
  # A deprecation warning is emitted; deal with it when the API is removed.
  # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
  if not os.path.exists(args.model_dir) or args.force_download:
    print("downloading wd14 tagger model from hf_hub")
    for file in FILES:
      hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
    for file in SUB_DIR_FILES:
      hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
          args.model_dir, SUB_DIR), force_download=True, force_filename=file)

  # Collect the images to tag.
  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
  print(f"found {len(image_paths)} images.")

  print("loading model and labels")
  model = load_model(args.model_dir)

  # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
  # Read the CSV by hand to avoid adding a pandas dependency.
  with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    l = [row for row in reader]
    header = l[0]  # tag_id,name,category,count
    rows = l[1:]
  assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"

  # NOTE(review): rows[1:] skips the first data row in addition to the header
  # already stripped above — presumably that row is a rating, not a tag;
  # verify against the repo's selected_tags.csv layout.
  tags = [row[1] for row in rows[1:] if row[2] == '0']  # keep only category 0, i.e. ordinary tags

  # Run inference on one accumulated batch and write caption files.
  def run_batch(path_imgs):
    imgs = np.array([im for _, im in path_imgs])

    probs = model(imgs, training=False)
    probs = probs.numpy()

    for (image_path, _), prob in zip(path_imgs, probs):
      # The first 4 outputs are ratings, so they are ignored.
      # # First 4 labels are actually ratings: pick one with argmax
      # ratings_names = label_names[:4]
      # rating_index = ratings_names["probs"].argmax()
      # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]

      # The rest are tags: keep any whose confidence exceeds the threshold.
      # Everything else is tags: pick any where prediction confidence > threshold
      tag_text = ""
      for i, p in enumerate(prob[4:]):  # numpy would be faster, but the count is small enough for a plain loop
        if p >= args.thresh:
          tag_text += ", " + tags[i]

      if len(tag_text) > 0:
        tag_text = tag_text[2:]  # drop the leading ", "

      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
        f.write(tag_text + '\n')
        if args.debug:
          print(image_path, tag_text)

  # Load, pad-to-square, and resize each image, batching for inference.
  b_imgs = []
  for image_path in tqdm(image_paths, smoothing=0.0):
    img = Image.open(image_path)  # use pillow: cv2 chokes on non-ASCII filenames, and we want mode conversion
    if img.mode != 'RGB':
      img = img.convert("RGB")
    img = np.array(img)
    img = img[:, :, ::-1]  # RGB->BGR

    # pad to square
    size = max(img.shape[0:2])
    pad_x = size - img.shape[1]
    pad_y = size - img.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)

    # INTER_AREA for downscaling, LANCZOS4 for upscaling.
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    # cv2.imshow("img", img)
    # cv2.waitKey()
    # cv2.destroyAllWindows()

    img = img.astype(np.float32)
    b_imgs.append((image_path, img))

    if len(b_imgs) >= args.batch_size:
      run_batch(b_imgs)
      b_imgs.clear()

  # Flush the final partial batch.
  if len(b_imgs) > 0:
    run_batch(b_imgs)

  print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
  # Command-line interface for the wd14 tagging script.
  cli = argparse.ArgumentParser()
  cli.add_argument("train_data_dir", type=str,
                   help="directory for train images / 学習画像データのディレクトリ")
  cli.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO,
                   help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
  cli.add_argument("--model_dir", type=str, default="wd14_tagger_model",
                   help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
  cli.add_argument("--force_download", action='store_true',
                   help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
  cli.add_argument("--thresh", type=float, default=0.35,
                   help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
  cli.add_argument("--batch_size", type=int, default=1,
                   help="batch size in inference / 推論時のバッチサイズ")
  cli.add_argument("--caption_extention", type=str, default=None,
                   help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
  cli.add_argument("--caption_extension", type=str, default=".txt",
                   help="extension of caption file / 出力されるキャプションファイルの拡張子")
  cli.add_argument("--debug", action="store_true", help="debug mode")

  args = cli.parse_args()

  # Restore the historically misspelled option name for backward compatibility.
  if args.caption_extention is not None:
    args.caption_extension = args.caption_extention

  main(args)
|
||||||
1228
train_db_fixed.py
Normal file
1228
train_db_fixed.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user