format by black

This commit is contained in:
Kohya S
2023-04-17 22:00:26 +09:00
parent 8f6fc8daa1
commit 47d61e2c02
5 changed files with 719 additions and 542 deletions

View File

@@ -14,17 +14,20 @@ from torchvision.transforms.functional import InterpolationMode
from blip.blip import blip_decoder from blip.blip import blip_decoder
import library.train_util as train_util import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 384 IMAGE_SIZE = 384
# 正方形でいいのか? という気がするがソースがそうなので # 正方形でいいのか? という気がするがソースがそうなので
IMAGE_TRANSFORM = transforms.Compose([ IMAGE_TRANSFORM = transforms.Compose(
[
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]) ]
)
# 共通化したいが微妙に処理が異なる…… # 共通化したいが微妙に処理が異なる……
class ImageLoadingTransformDataset(torch.utils.data.Dataset): class ImageLoadingTransformDataset(torch.utils.data.Dataset):
@@ -69,8 +72,8 @@ def main(args):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
cwd = os.getcwd() cwd = os.getcwd()
print('Current Working Directory is: ', cwd) print("Current Working Directory is: ", cwd)
os.chdir('finetune') os.chdir("finetune")
print(f"load images from {args.train_data_dir}") print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir) train_data_dir_path = Path(args.train_data_dir)
@@ -78,7 +81,7 @@ def main(args):
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}") print(f"loading BLIP caption: {args.caption_weights}")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json") model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
model.eval() model.eval()
model = model.to(DEVICE) model = model.to(DEVICE)
print("BLIP loaded") print("BLIP loaded")
@@ -89,13 +92,16 @@ def main(args):
with torch.no_grad(): with torch.no_grad():
if args.beam_search: if args.beam_search:
captions = model.generate(imgs, sample=False, num_beams=args.num_beams, captions = model.generate(
max_length=args.max_length, min_length=args.min_length) imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length
)
else: else:
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length) captions = model.generate(
imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length
)
for (image_path, _), caption in zip(path_imgs, captions): for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n") f.write(caption + "\n")
if args.debug: if args.debug:
print(image_path, caption) print(image_path, caption)
@@ -103,8 +109,14 @@ def main(args):
# 読み込みの高速化のためにDataLoaderを使うオプション # 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None: if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingTransformDataset(image_paths) dataset = ImageLoadingTransformDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, data = torch.utils.data.DataLoader(
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
else: else:
data = [[(None, ip)] for ip in image_paths] data = [[(None, ip)] for ip in image_paths]
@@ -118,7 +130,7 @@ def main(args):
if img_tensor is None: if img_tensor is None:
try: try:
raw_image = Image.open(image_path) raw_image = Image.open(image_path)
if raw_image.mode != 'RGB': if raw_image.mode != "RGB":
raw_image = raw_image.convert("RGB") raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image) img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e: except Exception as e:
@@ -138,28 +150,43 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", parser.add_argument(
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)") "--caption_weights",
parser.add_argument("--caption_extention", type=str, default=None, type=str,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)",
)
parser.add_argument(
"--caption_extention",
type=str,
default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--beam_search", action="store_true", parser.add_argument(
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling") "--beam_search",
action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None, parser.add_argument(
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化") "--max_data_loader_n_workers",
type=int,
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる") parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
return parser return parser
if __name__ == '__main__': if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()

View File

@@ -12,17 +12,17 @@ from transformers.generation.utils import GenerationMixin
import library.train_util as train_util import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATTERN_REPLACE = [ PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
re.compile(r'with the number \d+ on (it|\w+ \w+)'), re.compile(r"with the number \d+ on (it|\w+ \w+)"),
re.compile(r'with the words "'), re.compile(r'with the words "'),
re.compile(r'word \w+ on it'), re.compile(r"word \w+ on it"),
re.compile(r'that says the word \w+ on it'), re.compile(r"that says the word \w+ on it"),
re.compile('that says\'the word "( on it)?'), re.compile("that says'the word \"( on it)?"),
] ]
# 誤検知しまくりの with the word xxxx を消す # 誤検知しまくりの with the word xxxx を消す
@@ -63,6 +63,7 @@ def main(args):
if input_ids.size()[0] != curr_batch_size[0]: if input_ids.size()[0] != curr_batch_size[0]:
input_ids = input_ids.repeat(curr_batch_size[0], 1) input_ids = input_ids.repeat(curr_batch_size[0], 1)
return input_ids return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
print(f"load images from {args.train_data_dir}") print(f"load images from {args.train_data_dir}")
@@ -89,7 +90,7 @@ def main(args):
captions = remove_words(captions, args.debug) captions = remove_words(captions, args.debug)
for (image_path, _), caption in zip(path_imgs, captions): for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n") f.write(caption + "\n")
if args.debug: if args.debug:
print(image_path, caption) print(image_path, caption)
@@ -97,8 +98,14 @@ def main(args):
# 読み込みの高速化のためにDataLoaderを使うオプション # 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None: if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths) dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, data = torch.utils.data.DataLoader(
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
else: else:
data = [[(None, ip)] for ip in image_paths] data = [[(None, ip)] for ip in image_paths]
@@ -112,7 +119,7 @@ def main(args):
if image is None: if image is None:
try: try:
image = Image.open(image_path) image = Image.open(image_path)
if image.mode != 'RGB': if image.mode != "RGB":
image = image.convert("RGB") image = image.convert("RGB")
except Exception as e: except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
@@ -133,21 +140,32 @@ def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps", parser.add_argument(
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID") "--model_id",
type=str,
default="microsoft/git-large-textcaps",
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None, parser.add_argument(
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化") "--max_data_loader_n_workers",
type=int,
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument("--remove_words", action="store_true", parser.add_argument(
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") "--remove_words",
action="store_true",
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する",
)
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
return parser return parser
if __name__ == '__main__': if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()

View File

@@ -14,7 +14,7 @@ from torchvision import transforms
import library.model_util as model_util import library.model_util as model_util
import library.train_util as train_util import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_TRANSFORMS = transforms.Compose( IMAGE_TRANSFORMS = transforms.Compose(
[ [
@@ -52,7 +52,7 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
relative_path = "" relative_path = ""
if flip: if flip:
base_name += '_flip' base_name += "_flip"
if recursive and relative_path: if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name) return os.path.join(data_dir, relative_path, base_name)
@@ -60,7 +60,6 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
return os.path.join(data_dir, base_name) return os.path.join(data_dir, base_name)
def main(args): def main(args):
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0: if args.bucket_reso_steps % 8 > 0:
@@ -72,7 +71,7 @@ def main(args):
if os.path.exists(args.in_json): if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}") print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f: with open(args.in_json, "rt", encoding="utf-8") as f:
metadata = json.load(f) metadata = json.load(f)
else: else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}") print(f"no metadata / メタデータファイルがありません: {args.in_json}")
@@ -89,15 +88,18 @@ def main(args):
vae.to(DEVICE, dtype=weight_dtype) vae.to(DEVICE, dtype=weight_dtype)
# bucketのサイズを計算する # bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(',')]) max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso, bucket_manager = train_util.BucketManager(
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps) args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
)
if not args.bucket_no_upscale: if not args.bucket_no_upscale:
bucket_manager.make_buckets() bucket_manager.make_buckets()
else: else:
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます") print(
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
)
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
img_ar_errors = [] img_ar_errors = []
@@ -106,8 +108,9 @@ def main(args):
for bucket in bucket_manager.buckets: for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, img in bucket], weight_dtype) latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \ assert (
f"latent shape {latents.shape}, {bucket[0][1].shape}" latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
), f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _), latent in zip(bucket, latents): for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
@@ -118,12 +121,16 @@ def main(args):
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _), latent in zip(bucket, latents): for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) npz_file_name = get_npz_filename_wo_ext(
args.train_data_dir, image_key, args.full_path, True, args.recursive
)
np.savez(npz_file_name, latent) np.savez(npz_file_name, latent)
else: else:
# remove existing flipped npz # remove existing flipped npz
for image_key, _ in bucket: for image_key, _ in bucket:
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" npz_file_name = (
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
if os.path.isfile(npz_file_name): if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name) os.remove(npz_file_name)
@@ -133,8 +140,14 @@ def main(args):
# 読み込みの高速化のためにDataLoaderを使うオプション # 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None: if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths) dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, data = torch.utils.data.DataLoader(
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) dataset,
batch_size=1,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
else: else:
data = [[(None, ip)] for ip in image_paths] data = [[(None, ip)] for ip in image_paths]
@@ -149,7 +162,7 @@ def main(args):
else: else:
try: try:
image = Image.open(image_path) image = Image.open(image_path)
if image.mode != 'RGB': if image.mode != "RGB":
image = image.convert("RGB") image = image.convert("RGB")
except Exception as e: except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
@@ -166,23 +179,28 @@ def main(args):
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
if not args.bucket_no_upscale: if not args.bucket_no_upscale:
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
assert resized_size[0] == reso[0] or resized_size[1] == reso[ assert (
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" resized_size[0] == reso[0] or resized_size[1] == reso[1]
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ ), f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" assert (
resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
), f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ assert (
1], f"internal error resized size is small: {resized_size}, {reso}" resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
), f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshapeを確認して同じならskipする # 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing: if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
if args.flip_aug: if args.flip_aug:
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz") npz_files.append(
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
found = True found = True
for npz_file in npz_files: for npz_file in npz_files:
@@ -190,7 +208,7 @@ def main(args):
found = False found = False
break break
dat = np.load(npz_file)['arr_0'] dat = np.load(npz_file)["arr_0"]
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False found = False
break break
@@ -205,13 +223,15 @@ def main(args):
if resized_size[0] > reso[0]: if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0] trim_size = resized_size[0] - reso[0]
image = image[:, trim_size//2:trim_size//2 + reso[0]] image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
if resized_size[1] > reso[1]: if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1] trim_size = resized_size[1] - reso[1]
image = image[trim_size//2:trim_size//2 + reso[1]] image = image[trim_size // 2 : trim_size // 2 + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
# # debug # # debug
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
@@ -235,7 +255,7 @@ def main(args):
# metadataを書き出して終わり # metadataを書き出して終わり
print(f"writing metadata: {args.out_json}") print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f: with open(args.out_json, "wt", encoding="utf-8") as f:
json.dump(metadata, f, indent=2) json.dump(metadata, f, indent=2)
print("done!") print("done!")
@@ -246,34 +266,57 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None, parser.add_argument(
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化") "--max_data_loader_n_workers",
parser.add_argument("--max_resolution", type=str, default="512,512", type=int,
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)") default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--max_resolution",
type=str,
default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)",
)
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument("--bucket_reso_steps", type=int, default=64, parser.add_argument(
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します") "--bucket_reso_steps",
parser.add_argument("--bucket_no_upscale", action="store_true", type=int,
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") default=64,
parser.add_argument("--mixed_precision", type=str, default="no", help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") )
parser.add_argument("--full_path", action="store_true", parser.add_argument(
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
parser.add_argument("--flip_aug", action="store_true", )
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") parser.add_argument(
parser.add_argument("--skip_existing", action="store_true", "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ") )
parser.add_argument("--recursive", action="store_true", parser.add_argument(
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") "--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
return parser return parser
if __name__ == '__main__': if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()

View File

@@ -18,12 +18,13 @@ import library.train_util as train_util
IMAGE_SIZE = 448 IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2' DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables" SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1] CSV_FILE = FILES[-1]
def preprocess_image(image): def preprocess_image(image):
image = np.array(image) image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR image = image[:, :, ::-1] # RGB->BGR
@@ -34,7 +35,7 @@ def preprocess_image(image):
pad_y = size - image.shape[0] pad_y = size - image.shape[0]
pad_l = pad_x // 2 pad_l = pad_x // 2
pad_t = pad_y // 2 pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
@@ -42,6 +43,7 @@ def preprocess_image(image):
image = image.astype(np.float32) image = image.astype(np.float32)
return image return image
class ImageLoadingPrepDataset(torch.utils.data.Dataset): class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths): def __init__(self, image_paths):
self.images = image_paths self.images = image_paths
@@ -62,6 +64,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
return (tensor, img_path) return (tensor, img_path)
def collate_fn_remove_corrupted(batch): def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the """Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs. dataloader. It expects that the dataloader returns 'None' when that occurs.
@@ -71,6 +74,7 @@ def collate_fn_remove_corrupted(batch):
batch = list(filter(lambda x: x is not None, batch)) batch = list(filter(lambda x: x is not None, batch))
return batch return batch
def main(args): def main(args):
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時 # depreacatedの警告が出るけどなくなったらその時
@@ -80,8 +84,14 @@ def main(args):
for file in FILES: for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES: for file in SUB_DIR_FILES:
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( hf_hub_download(
args.model_dir, SUB_DIR), force_download=True, force_filename=file) args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
else: else:
print("using existing wd14 tagger model") print("using existing wd14 tagger model")
@@ -96,10 +106,10 @@ def main(args):
l = [row for row in reader] l = [row for row in reader]
header = l[0] # tag_id,name,category,count header = l[0] # tag_id,name,category,count
rows = l[1:] rows = l[1:]
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
general_tags = [row[1] for row in rows[1:] if row[2] == '0'] general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == '4'] character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
# 画像を読み込む # 画像を読み込む
@@ -109,7 +119,7 @@ def main(args):
tag_freq = {} tag_freq = {}
undesired_tags = set(args.undesired_tags.split(',')) undesired_tags = set(args.undesired_tags.split(","))
def run_batch(path_imgs): def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs]) imgs = np.array([im for _, im in path_imgs])
@@ -131,13 +141,17 @@ def main(args):
character_tag_text = "" character_tag_text = ""
for i, p in enumerate(prob[4:]): for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold: if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i].replace('_', ' ') if args.remove_underscore else general_tags[i] tag_name = general_tags[i].replace("_", " ") if args.remove_underscore else general_tags[i]
if tag_name not in undesired_tags: if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name general_tag_text += ", " + tag_name
combined_tags.append(tag_name) combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold: elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)].replace('_', ' ') if args.remove_underscore else character_tags[i - len(general_tags)] tag_name = (
character_tags[i - len(general_tags)].replace("_", " ")
if args.remove_underscore
else character_tags[i - len(general_tags)]
)
if tag_name not in undesired_tags: if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name character_tag_text += ", " + tag_name
@@ -149,19 +163,24 @@ def main(args):
if len(character_tag_text) > 0: if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:] character_tag_text = character_tag_text[2:]
tag_text = ', '.join(combined_tags) tag_text = ", ".join(combined_tags)
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(tag_text + '\n') f.write(tag_text + "\n")
if args.debug: if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
# 読み込みの高速化のためにDataLoaderを使うオプション # 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None: if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths) dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, data = torch.utils.data.DataLoader(
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
else: else:
data = [[(None, ip)] for ip in image_paths] data = [[(None, ip)] for ip in image_paths]
@@ -177,7 +196,7 @@ def main(args):
else: else:
try: try:
image = Image.open(image_path) image = Image.open(image_path)
if image.mode != 'RGB': if image.mode != "RGB":
image = image.convert("RGB") image = image.convert("RGB")
image = preprocess_image(image) image = preprocess_image(image)
except Exception as e: except Exception as e:
@@ -203,29 +222,65 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, parser.add_argument(
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") "--repo_id",
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", type=str,
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") default=DEFAULT_WD14_TAGGER_REPO,
parser.add_argument("--force_download", action='store_true', help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") )
parser.add_argument(
"--model_dir",
type=str,
default="wd14_tagger_model",
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
)
parser.add_argument(
"--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None, parser.add_argument(
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化") "--max_data_loader_n_workers",
parser.add_argument("--caption_extention", type=str, default=None, type=int,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--caption_extention",
type=str,
default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument("--general_threshold", type=float, default=None, help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ") parser.add_argument(
parser.add_argument("--character_threshold", type=float, default=None, help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ") "--general_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
)
parser.add_argument(
"--character_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
)
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
parser.add_argument("--remove_underscore", action="store_true", help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える") parser.add_argument(
"--remove_underscore",
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--undesired_tags", type=str, default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト") parser.add_argument(
parser.add_argument('--frequency_tags', action='store_true', help='Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する') "--undesired_tags",
type=str,
default="",
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
args = parser.parse_args() args = parser.parse_args()

View File

@@ -24,7 +24,9 @@ def convert(args):
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" assert (
is_save_ckpt or args.reference_model is not None
), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
# モデルを読み込む # モデルを読み込む
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
@@ -34,7 +36,9 @@ def convert(args):
v2_model = args.v2 v2_model = args.v2
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
else: else:
pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) pipe = StableDiffusionPipeline.from_pretrained(
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
)
text_encoder = pipe.text_encoder text_encoder = pipe.text_encoder
vae = pipe.vae vae = pipe.vae
unet = pipe.unet unet = pipe.unet
@@ -42,7 +46,7 @@ def convert(args):
if args.v1 == args.v2: if args.v1 == args.v2:
# 自動判定する # 自動判定する
v2_model = unet.config.cross_attention_dim == 1024 v2_model = unet.config.cross_attention_dim == 1024
print("checking model version: model is " + ('v2' if v2_model else 'v1')) print("checking model version: model is " + ("v2" if v2_model else "v1"))
else: else:
v2_model = not args.v1 v2_model = not args.v1
@@ -52,44 +56,74 @@ def convert(args):
if is_save_ckpt: if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None original_model = args.model_to_load if is_load_ckpt else None
key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, key_count = model_util.save_stable_diffusion_checkpoint(
original_model, args.epoch, args.global_step, save_dtype, vae) v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
)
print(f"model saved. total converted state_dict keys: {key_count}") print(f"model saved. total converted state_dict keys: {key_count}")
else: else:
print(f"copy scheduler/tokenizer config from: {args.reference_model}") print(f"copy scheduler/tokenizer config from: {args.reference_model}")
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors) model_util.save_diffusers_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
)
print(f"model saved.") print(f"model saved.")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the model conversion script.

    Options:
        --v1 / --v2: model version flags (one is required when loading a checkpoint).
        --fp16 / --bf16 / --float: precision flags.
        --save_precision_as: precision given by name; defaults to "no".
        --epoch / --global_step: integers written into a saved checkpoint (default 0).
        --reference_model: source of the scheduler/tokenizer when saving as Diffusers.
        --use_safetensors: save the Diffusers model in safetensors format.

    Positionals:
        model_to_load: checkpoint file or Diffusers model directory to read.
        model_to_save: checkpoint path (with extension) or Diffusers directory
            (without extension) to write.

    Returns:
        The configured ``argparse.ArgumentParser``; parsing is left to the caller.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
    )
    parser.add_argument(
        "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込みDiffusers形式のみ対応、保存するcheckpointのみ対応",
    )
    parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存するcheckpointのみ対応")
    parser.add_argument(
        "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存するcheckpointのみ対応"
    )
    parser.add_argument(
        "--save_precision_as",
        type=str,
        default="no",
        # "no" is the default value, so it must also be accepted when passed
        # explicitly; argparse does not validate defaults against `choices`.
        choices=["no", "fp16", "bf16", "float"],
        help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
    )
    parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
    parser.add_argument(
        "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
    )
    parser.add_argument(
        "--reference_model",
        type=str,
        default=None,
        help="reference model for scheduler/tokenizer, required in saving Diffusers, copy scheduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要",
    )
    parser.add_argument(
        "--use_safetensors",
        action="store_true",
        help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Diffusersモデルをsafetensors形式で保存するcheckpointは拡張子で自動判定",
    )

    # Required positionals: argparse never uses a default for them, so none is given.
    parser.add_argument(
        "model_to_load",
        type=str,
        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
    )
    parser.add_argument(
        "model_to_save",
        type=str,
        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存",
    )
    return parser
if __name__ == '__main__': if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()