Merge branch 'dev' into main

Kohya S authored 2023-02-04 18:16:03 +09:00, committed by GitHub.
14 changed files with 664 additions and 208 deletions.

.gitignore

@@ -3,4 +3,5 @@ __pycache__
 wd14_tagger_model
 venv
 *.egg-info
 build
+.vscode

README-ja.md

@@ -116,7 +116,7 @@ Answer the accelerate config questions as follows. bf1
 cd sd-scripts
 git pull
 .\venv\Scripts\activate
-pip install --upgrade -r <requirement file name>
+pip install --upgrade -r requirements.txt
 ```
 If the command succeeds, the new version is ready to use.

README.md

@@ -6,27 +6,39 @@ __Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__
 Note: LoRA models for SD 2.x are not supported in the Web UI either.
 
-- 29 Jan. 2023, 2023/1/29
-  - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
-  - Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``
-- 26 Jan. 2023, 2023/1/26
-  - Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)
-- 24 Jan. 2023, 2023/1/24
-  - Change the default save format to ``.safetensors`` for ``train_network.py``.
-  - Add ``--save_n_epoch_ratio`` option to specify how often to save. Thanks to forestsource!
-    - For example, if 5 is specified, 5 (or 6) files will be saved in training.
-  - Add feature to pre-calculate hash to reduce loading time in the extension. Thanks to space-nuko!
-  - Add bucketing metadata. Thanks to space-nuko!
-  - Fix an error with bf16 model in ``gen_img_diffusers.py``.
+- 3 Feb. 2023, 2023/2/3
+  - Update finetune preprocessing scripts.
+    - ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!
+    - The default weights of ``tag_images_by_wd14_tagger.py`` are now ``SmilingWolf/wd-v1-4-convnext-tagger-v2``. You can specify another model id from ``SmilingWolf`` with the ``--repo_id`` option. Thanks to SmilingWolf for the great work.
+      - To change the weights, remove the ``wd14_tagger_model`` folder and run the script again (see the sketch after this section).
+    - The ``--max_data_loader_n_workers`` option is added to each script. It uses a DataLoader for data loading, which makes loading 20%-30% faster.
+      - Specify 2 or 4, depending on the number of CPU cores.
+    - The ``--recursive`` option is added to ``merge_dd_tags_to_metadata.py`` and ``merge_captions_to_metadata.py``; it only works together with ``--full_path``.
+    - ``make_captions_by_git.py`` is added. It uses [GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) for captioning.
+      - ``requirements.txt`` is updated. If you use this script, [please update the libraries](https://github.com/kohya-ss/sd-scripts#upgrade).
+      - Usage is almost the same as ``make_captions.py``, but the batch size should be smaller.
+      - The ``--remove_words`` option removes as much overlaid-text description as possible (such as ``the word "XXXX" on it``).
+    - The ``--skip_existing`` option is added to ``prepare_buckets_latents.py``. Images whose npz files already exist are skipped.
+    - ``clean_captions_and_tags.py`` is updated to remove duplicated or conflicting tags, e.g. ``shirt`` is removed when ``white shirt`` exists; if ``black hair`` appears together with ``red hair``, both are removed.
+  - Tag frequency is added to the metadata in ``train_network.py``. Thanks to space-nuko!
+    - __All tags and the number of occurrences of each tag are recorded.__ If you do not want this, disable metadata storing with the ``--no_metadata`` option.
 
 Image generation with LoRA models trained with this repository now seems to be supported in the Stable Diffusion web UI itself.
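A usage sketch of the tagger update above (not a definitive invocation: the directory name ``train_data`` is illustrative, and the ``finetune/`` location is assumed; the repo id is one of the alternatives listed in the script's own comment, and ``wd14_tagger_model`` is the script's default ``--model_dir``):

```
# Sketch: switch to different tagger weights by removing the cached model
# folder, then rerunning with another SmilingWolf repo id via --repo_id.
rm -r wd14_tagger_model
python finetune/tag_images_by_wd14_tagger.py --repo_id SmilingWolf/wd-v1-4-swinv2-tagger-v2 \
    --batch_size 4 --max_data_loader_n_workers 2 train_data
```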

finetune/clean_captions_and_tags.py

@@ -5,13 +5,32 @@ import argparse
 import glob
 import os
 import json
+import re
 
 from tqdm import tqdm
 
+
+PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
+PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
+PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
+PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
+
+# when several people are in the image, remove tags that define multiple hair or eye colors
+PATTERNS_REMOVE_IN_MULTI = [
+    PATTERN_HAIR_LENGTH,
+    PATTERN_HAIR_CUT,
+    re.compile(r', [\w\-]+ eyes, '),
+    re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
+    # also remove when several hairstyles are defined
+    re.compile(
+        r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
+]
+
 
 def clean_tags(image_key, tags):
   # replace '_' to ' '
+  tags = tags.replace('^_^', '^@@@^')
   tags = tags.replace('_', ' ')
+  tags = tags.replace('^@@@^', '^_^')
 
   # remove rating: deepdanbooru only
   tokens = tags.split(", rating")
@@ -26,6 +45,37 @@ def clean_tags(image_key, tags):
     print(f"{image_key} {tags}")
   tags = tokens[0]
 
+  tags = ", " + tags.replace(", ", ", , ") + ", "  # blunt workaround so every tag can be searched together with its surrounding commas
+
+  # if several people are in the image, remove hair-color and similar tags
+  if 'girls' in tags or 'boys' in tags:
+    for pat in PATTERNS_REMOVE_IN_MULTI:
+      found = pat.findall(tags)
+      if len(found) > 1:    # two or more matching tags
+        tags = pat.sub("", tags)
+
+    # special handling for hair
+    srch_hair_len = PATTERN_HAIR_LENGTH.search(tags)  # set hair-length tags aside as an exception (everyone may share the same hair length)
+    if srch_hair_len:
+      org = srch_hair_len.group()
+      tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
+
+    found = PATTERN_HAIR.findall(tags)
+    if len(found) > 1:
+      tags = PATTERN_HAIR.sub("", tags)
+
+    if srch_hair_len:
+      tags = tags.replace(", @@@, ", org)  # restore the hair-length tag
+
+  # remove duplicated tags such as "shirt" when "white shirt" exists
+  found = PATTERN_WORD.findall(tags)
+  for word in found:
+    if re.search(f", ((\w+) )+{word}, ", tags):
+      tags = tags.replace(f", {word}, ", "")
+
+  tags = tags.replace(", , ", ", ")
+  assert tags.startswith(", ") and tags.endswith(", ")
+  tags = tags[2:-2]
+
   return tags
 
@@ -88,13 +138,23 @@ def main(args):
     if tags is None:
       print(f"image does not have tags / メタデータにタグがありません: {image_key}")
     else:
-      metadata[image_key]['tags'] = clean_tags(image_key, tags)
+      org = tags
+      tags = clean_tags(image_key, tags)
+      metadata[image_key]['tags'] = tags
+      if args.debug and org != tags:
+        print("FROM: " + org)
+        print("TO: " + tags)
 
     caption = metadata[image_key].get('caption')
     if caption is None:
       print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
     else:
-      metadata[image_key]['caption'] = clean_caption(caption)
+      org = caption
+      caption = clean_caption(caption)
+      metadata[image_key]['caption'] = caption
+      if args.debug and org != caption:
+        print("FROM: " + org)
+        print("TO: " + caption)
 
   # write out the metadata and finish
   print(f"writing metadata: {args.out_json}")
@@ -108,6 +168,7 @@ if __name__ == '__main__':
   # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
   parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
   parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
+  parser.add_argument("--debug", action="store_true", help="debug mode")
 
   args, unknown = parser.parse_known_args()
   if len(unknown) == 1:
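A usage sketch for the updated script (file names illustrative; the ``finetune/`` location is assumed): with the new ``--debug`` flag, every tag and caption string that changes is printed before (``FROM:``) and after (``TO:``) cleaning.

```
python finetune/clean_captions_and_tags.py meta_in.json meta_out.json --debug
```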

finetune/make_captions.py

@@ -11,18 +11,59 @@ import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from blip.blip import blip_decoder
-# from Salesforce_BLIP.models.blip import blip_decoder
+import library.train_util as train_util
 
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+IMAGE_SIZE = 384
+
+# is a square really OK? but the original source does it this way, so follow it
+IMAGE_TRANSFORM = transforms.Compose([
+    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+])
+
+
+# would like to share this with the other scripts, but the processing differs slightly...
+class ImageLoadingTransformDataset(torch.utils.data.Dataset):
+  def __init__(self, image_paths):
+    self.images = image_paths
+
+  def __len__(self):
+    return len(self.images)
+
+  def __getitem__(self, idx):
+    img_path = self.images[idx]
+
+    try:
+      image = Image.open(img_path).convert("RGB")
+      # convert to tensor temporarily so dataloader will accept it
+      tensor = IMAGE_TRANSFORM(image)
+    except Exception as e:
+      print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+      return None
+
+    return (tensor, img_path)
+
+
+def collate_fn_remove_corrupted(batch):
+  """Collate function that allows to remove corrupted examples in the
+  dataloader. It expects that the dataloader returns 'None' when that occurs.
+  The 'None's in the batch are removed.
+  """
+  # Filter out all the Nones (corrupted examples)
+  batch = list(filter(lambda x: x is not None, batch))
+  return batch
+
 
 def main(args):
   # fix the seed for reproducibility
   seed = args.seed  # + utils.get_rank()
   torch.manual_seed(seed)
   np.random.seed(seed)
   random.seed(seed)
 
   if not os.path.exists("blip"):
     args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path
@@ -31,24 +72,15 @@ def main(args):
     os.chdir('finetune')
 
   print(f"load images from {args.train_data_dir}")
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  image_paths = train_util.glob_images(args.train_data_dir)
   print(f"found {len(image_paths)} images.")
 
   print(f"loading BLIP caption: {args.caption_weights}")
-  image_size = 384
-  model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
+  model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
   model.eval()
   model = model.to(DEVICE)
   print("BLIP loaded")
 
-  # is a square really OK? but the original source does it this way, so follow it
-  transform = transforms.Compose([
-      transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
-      transforms.ToTensor(),
-      transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-  ])
-
   # run captioning
   def run_batch(path_imgs):
     imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
@@ -66,18 +98,35 @@ def main(args):
       if args.debug:
         print(image_path, caption)
 
-  b_imgs = []
-  for image_path in tqdm(image_paths, smoothing=0.0):
-    raw_image = Image.open(image_path)
-    if raw_image.mode != "RGB":
-      print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
-      raw_image = raw_image.convert("RGB")
-
-    image = transform(raw_image)
-    b_imgs.append((image_path, image))
-    if len(b_imgs) >= args.batch_size:
-      run_batch(b_imgs)
-      b_imgs.clear()
+  # option to use a DataLoader to speed up image loading
+  if args.max_data_loader_n_workers is not None:
+    dataset = ImageLoadingTransformDataset(image_paths)
+    data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
+                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
+  else:
+    data = [[(None, ip)] for ip in image_paths]
+
+  b_imgs = []
+  for data_entry in tqdm(data, smoothing=0.0):
+    for data in data_entry:
+      if data is None:
+        continue
+
+      img_tensor, image_path = data
+      if img_tensor is None:
+        try:
+          raw_image = Image.open(image_path)
+          if raw_image.mode != 'RGB':
+            raw_image = raw_image.convert("RGB")
+          img_tensor = IMAGE_TRANSFORM(raw_image)
+        except Exception as e:
+          print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+          continue
+
+      b_imgs.append((image_path, img_tensor))
+      if len(b_imgs) >= args.batch_size:
+        run_batch(b_imgs)
+        b_imgs.clear()
 
   if len(b_imgs) > 0:
     run_batch(b_imgs)
@@ -95,6 +144,8 @@ if __name__ == '__main__':
   parser.add_argument("--beam_search", action="store_true",
                       help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
   parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
+                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
   parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search / beam search時のビーム数(多いと精度が上がるが時間がかかる)")
   parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
   parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")

finetune/make_captions_by_git.py

@@ -0,0 +1,145 @@
+import argparse
+import os
+import re
+
+from PIL import Image
+from tqdm import tqdm
+import torch
+from transformers import AutoProcessor, AutoModelForCausalLM
+from transformers.generation.utils import GenerationMixin
+
+import library.train_util as train_util
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+PATTERN_REPLACE = [
+    re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
+    re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
+    re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
+    re.compile(r'with the number \d+ on (it|\w+ \w+)'),
+    re.compile(r'with the words "'),
+    re.compile(r'word \w+ on it'),
+    re.compile(r'that says the word \w+ on it'),
+    re.compile('that says\'the word "( on it)?'),
+]
+
+
+# remove the frequently misdetected "with the word xxxx" phrases
+def remove_words(captions, debug):
+  removed_caps = []
+  for caption in captions:
+    cap = caption
+    for pat in PATTERN_REPLACE:
+      cap = pat.sub("", cap)
+    if debug and cap != caption:
+      print(caption)
+      print(cap)
+    removed_caps.append(cap)
+  return removed_caps
+
+
+def collate_fn_remove_corrupted(batch):
+  """Collate function that allows to remove corrupted examples in the
+  dataloader. It expects that the dataloader returns 'None' when that occurs.
+  The 'None's in the batch are removed.
+  """
+  # Filter out all the Nones (corrupted examples)
+  batch = list(filter(lambda x: x is not None, batch))
+  return batch
+
+
+def main(args):
+  # patch GIT so it also works with batch sizes larger than 1: for transformers 4.26.0
+  org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
+  curr_batch_size = [args.batch_size]  # wrapped in a list so it can be swapped when the final batch is smaller than batch_size
+
+  # input_ids must have as many rows as the batch size; the batch size is not visible
+  # from inside this function, so it is passed in from the outside
+  # replacing things higher up the call stack would be much more work
+  def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
+    input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
+    if input_ids.size()[0] != curr_batch_size[0]:
+      input_ids = input_ids.repeat(curr_batch_size[0], 1)
+    return input_ids
+  GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
+
+  print(f"load images from {args.train_data_dir}")
+  image_paths = train_util.glob_images(args.train_data_dir)
+  print(f"found {len(image_paths)} images.")
+
+  # ideally this would be downloaded explicitly instead of relying on the cache
+  print(f"loading GIT: {args.model_id}")
+  git_processor = AutoProcessor.from_pretrained(args.model_id)
+  git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
+  print("GIT loaded")
+
+  # run captioning
+  def run_batch(path_imgs):
+    imgs = [im for _, im in path_imgs]
+
+    curr_batch_size[0] = len(path_imgs)
+    inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE)  # the images are in PIL format
+    generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
+    captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+    if args.remove_words:
+      captions = remove_words(captions, args.debug)
+
+    for (image_path, _), caption in zip(path_imgs, captions):
+      with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
+        f.write(caption + "\n")
+        if args.debug:
+          print(image_path, caption)
+
+  # option to use a DataLoader to speed up image loading
+  if args.max_data_loader_n_workers is not None:
+    dataset = train_util.ImageLoadingDataset(image_paths)
+    data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
+                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
+  else:
+    data = [[(None, ip)] for ip in image_paths]
+
+  b_imgs = []
+  for data_entry in tqdm(data, smoothing=0.0):
+    for data in data_entry:
+      if data is None:
+        continue
+
+      image, image_path = data
+      if image is None:
+        try:
+          image = Image.open(image_path)
+          if image.mode != 'RGB':
+            image = image.convert("RGB")
+        except Exception as e:
+          print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+          continue
+
+      b_imgs.append((image_path, image))
+      if len(b_imgs) >= args.batch_size:
+        run_batch(b_imgs)
+        b_imgs.clear()
+
+  if len(b_imgs) > 0:
+    run_batch(b_imgs)
+
+  print("done!")
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+  parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
+  parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
+                      help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
+  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
+                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
+  parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
+  parser.add_argument("--remove_words", action="store_true",
+                      help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
+  parser.add_argument("--debug", action="store_true", help="debug mode")
+
+  args = parser.parse_args()
+  main(args)
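A usage sketch for the new script (the directory name is illustrative and the ``finetune/`` location is assumed; all flags come from the argparse definition above). Per the changelog, keep the batch size smaller than for ``make_captions.py``:

```
python finetune/make_captions_by_git.py --batch_size 4 --max_data_loader_n_workers 2 \
    --remove_words --debug train_data
```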

finetune/merge_captions_to_metadata.py

@@ -1,26 +1,24 @@
-# This script is licensed under the Apache License 2.0
-# (c) 2022 Kohya S. @kohya_ss
-
 import argparse
-import glob
-import os
 import json
+from pathlib import Path
+from typing import List
 from tqdm import tqdm
+import library.train_util as train_util
 
 
 def main(args):
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
+
+  train_data_dir_path = Path(args.train_data_dir)
+  image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
   print(f"found {len(image_paths)} images.")
 
-  if args.in_json is None and os.path.isfile(args.out_json):
+  if args.in_json is None and Path(args.out_json).is_file():
     args.in_json = args.out_json
 
   if args.in_json is not None:
     print(f"loading existing metadata: {args.in_json}")
-    with open(args.in_json, "rt", encoding='utf-8') as f:
-      metadata = json.load(f)
+    metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
     print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
   else:
     print("new metadata will be created / 新しいメタデータファイルが作成されます")
@@ -28,11 +26,10 @@ def main(args):
   print("merge caption texts to metadata json.")
   for image_path in tqdm(image_paths):
-    caption_path = os.path.splitext(image_path)[0] + args.caption_extension
-    with open(caption_path, "rt", encoding='utf-8') as f:
-      caption = f.readlines()[0].strip()
+    caption_path = image_path.with_suffix(args.caption_extension)
+    caption = caption_path.read_text(encoding='utf-8').strip()
 
-    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
+    image_key = str(image_path) if args.full_path else image_path.stem
     if image_key not in metadata:
       metadata[image_key] = {}
@@ -42,8 +39,7 @@ def main(args):
   # write out the metadata and finish
   print(f"writing metadata: {args.out_json}")
-  with open(args.out_json, "wt", encoding='utf-8') as f:
-    json.dump(metadata, f, indent=2)
+  Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
   print("done!")
@@ -51,12 +47,15 @@ if __name__ == '__main__':
   parser = argparse.ArgumentParser()
   parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
   parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
-  parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
+  parser.add_argument("--in_json", type=str,
+                      help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
   parser.add_argument("--caption_extention", type=str, default=None,
                       help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
   parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
   parser.add_argument("--full_path", action="store_true",
                       help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
+  parser.add_argument("--recursive", action="store_true",
+                      help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
   parser.add_argument("--debug", action="store_true", help="debug mode")
 
   args = parser.parse_args()

finetune/merge_dd_tags_to_metadata.py

@@ -1,26 +1,24 @@
-# This script is licensed under the Apache License 2.0
-# (c) 2022 Kohya S. @kohya_ss
-
 import argparse
-import glob
-import os
 import json
+from pathlib import Path
+from typing import List
 from tqdm import tqdm
+import library.train_util as train_util
 
 
 def main(args):
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
+
+  train_data_dir_path = Path(args.train_data_dir)
+  image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
   print(f"found {len(image_paths)} images.")
 
-  if args.in_json is None and os.path.isfile(args.out_json):
+  if args.in_json is None and Path(args.out_json).is_file():
     args.in_json = args.out_json
 
   if args.in_json is not None:
     print(f"loading existing metadata: {args.in_json}")
-    with open(args.in_json, "rt", encoding='utf-8') as f:
-      metadata = json.load(f)
+    metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
     print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
   else:
     print("new metadata will be created / 新しいメタデータファイルが作成されます")
@@ -28,11 +26,10 @@ def main(args):
   print("merge tags to metadata json.")
   for image_path in tqdm(image_paths):
-    tags_path = os.path.splitext(image_path)[0] + '.txt'
-    with open(tags_path, "rt", encoding='utf-8') as f:
-      tags = f.readlines()[0].strip()
+    tags_path = image_path.with_suffix(args.caption_extension)
+    tags = tags_path.read_text(encoding='utf-8').strip()
 
-    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
+    image_key = str(image_path) if args.full_path else image_path.stem
     if image_key not in metadata:
       metadata[image_key] = {}
@@ -42,8 +39,8 @@ def main(args):
   # write out the metadata and finish
   print(f"writing metadata: {args.out_json}")
-  with open(args.out_json, "wt", encoding='utf-8') as f:
-    json.dump(metadata, f, indent=2)
+  Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
   print("done!")
@@ -51,9 +48,14 @@ if __name__ == '__main__':
   parser = argparse.ArgumentParser()
   parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
   parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
-  parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
+  parser.add_argument("--in_json", type=str,
+                      help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
   parser.add_argument("--full_path", action="store_true",
                       help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
+  parser.add_argument("--recursive", action="store_true",
+                      help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
+  parser.add_argument("--caption_extension", type=str, default=".txt",
+                      help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
   parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
 
   args = parser.parse_args()

finetune/prepare_buckets_latents.py

@@ -1,20 +1,16 @@
-# This script is licensed under the Apache License 2.0
-# (c) 2022 Kohya S. @kohya_ss
-
 import argparse
-import glob
 import os
 import json
 
 from tqdm import tqdm
 import numpy as np
-from diffusers import AutoencoderKL
 from PIL import Image
 import cv2
 import torch
 from torchvision import transforms
 
 import library.model_util as model_util
+import library.train_util as train_util
 
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -26,6 +22,16 @@ IMAGE_TRANSFORMS = transforms.Compose(
 )
 
+
+def collate_fn_remove_corrupted(batch):
+  """Collate function that allows to remove corrupted examples in the
+  dataloader. It expects that the dataloader returns 'None' when that occurs.
+  The 'None's in the batch are removed.
+  """
+  # Filter out all the Nones (corrupted examples)
+  batch = list(filter(lambda x: x is not None, batch))
+  return batch
+
 
 def get_latents(vae, images, weight_dtype):
   img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
   img_tensors = torch.stack(img_tensors)
@@ -35,9 +41,18 @@ def get_latents(vae, images, weight_dtype):
   return latents
 
 
+def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
+  if is_full_path:
+    base_name = os.path.splitext(os.path.basename(image_key))[0]
+  else:
+    base_name = image_key
+  if flip:
+    base_name += '_flip'
+  return os.path.join(data_dir, base_name)
+
+
 def main(args):
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  image_paths = train_util.glob_images(args.train_data_dir)
   print(f"found {len(image_paths)} images.")
 
   if os.path.exists(args.in_json):
@@ -70,15 +85,56 @@ def main(args):
   buckets_imgs = [[] for _ in range(len(bucket_resos))]
   bucket_counts = [0 for _ in range(len(bucket_resos))]
   img_ar_errors = []
-  for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
+
+  def process_batch(is_last):
+    for j in range(len(buckets_imgs)):
+      bucket = buckets_imgs[j]
+      if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
+        latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
+
+        for (image_key, _, _), latent in zip(bucket, latents):
+          npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
+          np.savez(npz_file_name, latent)
+
+        # flip
+        if args.flip_aug:
+          latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype)  # without the copy, the Tensor conversion fails
+
+          for (image_key, _, _), latent in zip(bucket, latents):
+            npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
+            np.savez(npz_file_name, latent)
+
+        bucket.clear()
+
+  # option to use a DataLoader to speed up image loading
+  if args.max_data_loader_n_workers is not None:
+    dataset = train_util.ImageLoadingDataset(image_paths)
+    data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
+                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
+  else:
+    data = [[(None, ip)] for ip in image_paths]
+
+  for data_entry in tqdm(data, smoothing=0.0):
+    if data_entry[0] is None:
+      continue
+
+    img_tensor, image_path = data_entry[0]
+    if img_tensor is not None:
+      image = transforms.functional.to_pil_image(img_tensor)
+    else:
+      try:
+        image = Image.open(image_path)
+        if image.mode != 'RGB':
+          image = image.convert("RGB")
+      except Exception as e:
+        print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+        continue
+
     image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
     if image_key not in metadata:
       metadata[image_key] = {}
 
-    image = Image.open(image_path)  # this part could also move into the Dataset for more speed, but it would be a lot of work
-    if image.mode != 'RGB':
-      image = image.convert("RGB")
-
     aspect_ratio = image.width / image.height
     ar_errors = bucket_aspect_ratios - aspect_ratio
     bucket_id = np.abs(ar_errors).argmin()
@@ -102,6 +158,25 @@ def main(args):
     assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
         1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
 
+    # if the files already exist, check their shape and skip when identical
+    if args.skip_existing:
+      npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
+      if args.flip_aug:
+        npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
+
+      found = True
+      for npz_file in npz_files:
+        if not os.path.exists(npz_file):
+          found = False
+          break
+
+        dat = np.load(npz_file)['arr_0']
+        if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8:  # check the shape of the latents
+          found = False
+          break
+      if found:
+        continue
+
     # resize and crop the image
     # PIL has no inter_area, so use cv2...
     image = np.array(image)
@@ -123,25 +198,10 @@ def main(args):
     metadata[image_key]['train_resolution'] = reso
 
     # decide whether the batch should be run, and run inference
-    is_last = i == len(image_paths) - 1
-    for j in range(len(buckets_imgs)):
-      bucket = buckets_imgs[j]
-      if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
-        latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
-        for (image_key, reso, _), latent in zip(bucket, latents):
-          npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
-          np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
-
-        # flip
-        if args.flip_aug:
-          latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype)  # without the copy, the Tensor conversion fails
-          for (image_key, reso, _), latent in zip(bucket, latents):
-            npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
-            np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
-
-        bucket.clear()
+    process_batch(False)
+
+  # process the remainder
+  process_batch(True)
 
   for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
     print(f"bucket {i} {reso}: {count}")
@@ -162,8 +222,10 @@ if __name__ == '__main__':
   parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
   parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
   parser.add_argument("--v2", action='store_true',
-                      help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
+                      help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
   parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
+                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
   parser.add_argument("--max_resolution", type=str, default="512,512",
                       help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
   parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
@@ -174,6 +236,8 @@ if __name__ == '__main__':
                       help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
   parser.add_argument("--flip_aug", action="store_true",
                       help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
+  parser.add_argument("--skip_existing", action="store_true",
+                      help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
 
   args = parser.parse_args()
   main(args)
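A usage sketch (file names illustrative; the positional order is an assumption based on the argument names referenced in the script, so check ``--help``):

```
# Re-runs only encode images whose npz latents are missing or have a different shape.
python finetune/prepare_buckets_latents.py train_data meta_cap.json meta_lat.json model.ckpt \
    --batch_size 4 --max_data_loader_n_workers 2 --flip_aug --skip_existing
```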

finetune/tag_images_by_wd14_tagger.py

@@ -1,6 +1,3 @@
-# This script is licensed under the Apache License 2.0
-# (c) 2022 Kohya S. @kohya_ss
-
 import argparse
 import csv
 import glob
@@ -12,32 +9,87 @@ from tqdm import tqdm
 import numpy as np
 from tensorflow.keras.models import load_model
 from huggingface_hub import hf_hub_download
+import torch
+
+import library.train_util as train_util
 
 # from wd14 tagger
 IMAGE_SIZE = 448
 
-WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger'
+# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
+DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
 FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
 SUB_DIR = "variables"
 SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
 CSV_FILE = FILES[-1]
 
+
+def preprocess_image(image):
+  image = np.array(image)
+  image = image[:, :, ::-1]  # RGB->BGR
+
+  # pad to square
+  size = max(image.shape[0:2])
+  pad_x = size - image.shape[1]
+  pad_y = size - image.shape[0]
+  pad_l = pad_x // 2
+  pad_t = pad_y // 2
+  image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
+
+  interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
+  image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
+
+  image = image.astype(np.float32)
+  return image
+
+
+class ImageLoadingPrepDataset(torch.utils.data.Dataset):
+  def __init__(self, image_paths):
+    self.images = image_paths
+
+  def __len__(self):
+    return len(self.images)
+
+  def __getitem__(self, idx):
+    img_path = self.images[idx]
+
+    try:
+      image = Image.open(img_path).convert("RGB")
+      image = preprocess_image(image)
+      tensor = torch.tensor(image)
+    except Exception as e:
+      print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+      return None
+
+    return (tensor, img_path)
+
+
+def collate_fn_remove_corrupted(batch):
+  """Collate function that allows to remove corrupted examples in the
+  dataloader. It expects that the dataloader returns 'None' when that occurs.
+  The 'None's in the batch are removed.
+  """
+  # Filter out all the Nones (corrupted examples)
+  batch = list(filter(lambda x: x is not None, batch))
+  return batch
+
 
 def main(args):
   # using hf_hub_download directly apparently has problems with symlinks, so work around them with a cache directory and force_filename
   # a deprecation warning is shown, but deal with that when the option actually goes away
   # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
   if not os.path.exists(args.model_dir) or args.force_download:
-    print("downloading wd14 tagger model from hf_hub")
+    print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
     for file in FILES:
       hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
     for file in SUB_DIR_FILES:
      hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
          args.model_dir, SUB_DIR), force_download=True, force_filename=file)
+  else:
+    print("using existing wd14 tagger model")
 
   # load the images
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  image_paths = train_util.glob_images(args.train_data_dir)
   print(f"found {len(image_paths)} images.")
 
   print("loading model and labels")
@@ -72,7 +124,7 @@ def main(args):
     # Everything else is tags: pick any where prediction confidence > threshold
     tag_text = ""
     for i, p in enumerate(prob[4:]):  # numpy would be nicer, but there are not many items, so a loop is fine
-      if p >= args.thresh:
+      if p >= args.thresh and i < len(tags):
        tag_text += ", " + tags[i]
 
     if len(tag_text) > 0:
@@ -83,34 +135,37 @@ def main(args):
     if args.debug:
       print(image_path, tag_text)
 
+  # option to use a DataLoader to speed up image loading
+  if args.max_data_loader_n_workers is not None:
+    dataset = ImageLoadingPrepDataset(image_paths)
+    data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
+                                       num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
+  else:
+    data = [[(None, ip)] for ip in image_paths]
+
   b_imgs = []
-  for image_path in tqdm(image_paths, smoothing=0.0):
-    img = Image.open(image_path)  # cv2 chokes on Japanese file names and we also want to convert the mode, so open with pillow
-    if img.mode != 'RGB':
-      img = img.convert("RGB")
-    img = np.array(img)
-    img = img[:, :, ::-1]  # RGB->BGR
-
-    # pad to square
-    size = max(img.shape[0:2])
-    pad_x = size - img.shape[1]
-    pad_y = size - img.shape[0]
-    pad_l = pad_x // 2
-    pad_t = pad_y // 2
-    img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
-
-    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
-    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
-    # cv2.imshow("img", img)
-    # cv2.waitKey()
-    # cv2.destroyAllWindows()
-    img = img.astype(np.float32)
-    b_imgs.append((image_path, img))
-    if len(b_imgs) >= args.batch_size:
-      run_batch(b_imgs)
-      b_imgs.clear()
+  for data_entry in tqdm(data, smoothing=0.0):
+    for data in data_entry:
+      if data is None:
+        continue
+
+      image, image_path = data
+      if image is not None:
+        image = image.detach().numpy()
+      else:
+        try:
+          image = Image.open(image_path)
+          if image.mode != 'RGB':
+            image = image.convert("RGB")
+          image = preprocess_image(image)
+        except Exception as e:
+          print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+          continue
+      b_imgs.append((image_path, image))
+
+      if len(b_imgs) >= args.batch_size:
+        run_batch(b_imgs)
+        b_imgs.clear()
 
   if len(b_imgs) > 0:
     run_batch(b_imgs)
@@ -121,7 +176,7 @@ def main(args):
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
   parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
-  parser.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO,
+  parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
                       help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
   parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
                       help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
@@ -129,6 +184,8 @@ if __name__ == '__main__':
                       help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
   parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
   parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+  parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
+                      help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
   parser.add_argument("--caption_extention", type=str, default=None,
                       help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
   parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")

gen_img_diffusers.py

@@ -1845,12 +1845,12 @@ def main(args):
     text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
   else:
     print("load Diffusers pretrained models")
-    pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
-    text_encoder = pipe.text_encoder
-    vae = pipe.vae
-    unet = pipe.unet
-    tokenizer = pipe.tokenizer
-    del pipe
+    loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
+    text_encoder = loading_pipe.text_encoder
+    vae = loading_pipe.vae
+    unet = loading_pipe.unet
+    tokenizer = loading_pipe.tokenizer
+    del loading_pipe
 
   # load the VAE
   if args.vae is not None:

library/train_util.py

@@ -45,6 +45,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset # region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
class ImageInfo(): class ImageInfo():
@@ -87,6 +88,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.enable_bucket = False self.enable_bucket = False
self.min_bucket_reso = None self.min_bucket_reso = None
self.max_bucket_reso = None self.max_bucket_reso = None
self.tag_frequency = {}
self.bucket_info = None self.bucket_info = None
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -115,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {} self.replacements = {}
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
if tag and not tag.isspace():
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
def disable_token_padding(self): def disable_token_padding(self):
self.token_padding_disabled = True self.token_padding_disabled = True
@@ -247,7 +259,6 @@ class BaseDataset(torch.utils.data.Dataset):
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}") print(f"mean ar error (without repeats): {mean_img_ar_error}")
# 参照用indexを作る # 参照用indexを作る
self.buckets_indices: list(BucketBatchIndex) = [] self.buckets_indices: list(BucketBatchIndex) = []
for bucket_index, bucket in enumerate(self.buckets): for bucket_index, bucket in enumerate(self.buckets):
@@ -545,6 +556,8 @@ class DreamBoothDataset(BaseDataset):
cap_for_img = read_caption(img_path) cap_for_img = read_caption(img_path)
captions.append(caption_by_folder if cap_for_img is None else cap_for_img) captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
return n_repeats, img_paths, captions return n_repeats, img_paths, captions
print("prepare train images.") print("prepare train images.")
@@ -553,10 +566,13 @@ class DreamBoothDataset(BaseDataset):
for dir in train_dirs: for dir in train_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
num_train_images += n_repeats * len(img_paths) num_train_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path) info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info) self.register_image(info)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.") print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images self.num_train_images = num_train_images
@@ -570,9 +586,11 @@ class DreamBoothDataset(BaseDataset):
for dir in reg_dirs: for dir in reg_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
num_reg_images += n_repeats * len(img_paths) num_reg_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path) info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info) reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.") print(f"{num_reg_images} reg images.")
@@ -617,6 +635,7 @@ class FineTuningDataset(BaseDataset):
self.train_data_dir = train_data_dir self.train_data_dir = train_data_dir
self.batch_size = batch_size self.batch_size = batch_size
tags_list = []
for image_key, img_md in metadata.items(): for image_key, img_md in metadata.items():
# path情報を作る # path情報を作る
if os.path.exists(image_key): if os.path.exists(image_key):
@@ -633,6 +652,7 @@ class FineTuningDataset(BaseDataset):
caption = tags caption = tags
elif tags is not None and len(tags) > 0: elif tags is not None and len(tags) > 0:
caption = caption + ', ' + tags caption = caption + ', ' + tags
tags_list.append(tags)
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
@@ -646,7 +666,8 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0 self.num_reg_images = 0
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
# check existence of all npz files # check existence of all npz files
if not self.color_aug: if not self.color_aug:
@@ -667,6 +688,8 @@ class FineTuningDataset(BaseDataset):
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
elif not npz_all: elif not npz_all:
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
if self.flip_aug:
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
for image_info in self.image_data.values(): for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None image_info.latents_npz = image_info.latents_npz_flipped = None
@@ -756,15 +779,30 @@ def debug_dataset(train_dataset, show_input_ids=False):
break
def glob_images(directory, base="*"):
    img_paths = []
    for ext in IMAGE_EXTENSIONS:
        if base == '*':
            img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
        else:
            img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
    # img_paths = list(set(img_paths))  # remove duplicates / 重複を排除
    # img_paths.sort()
    return img_paths
def glob_images_pathlib(dir_path, recursive):
    image_paths = []
    if recursive:
        for ext in IMAGE_EXTENSIONS:
            image_paths += list(dir_path.rglob('*' + ext))
    else:
        for ext in IMAGE_EXTENSIONS:
            image_paths += list(dir_path.glob('*' + ext))
    # image_paths = list(set(image_paths))  # remove duplicates / 重複を排除
    # image_paths.sort()
    return image_paths
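Both helpers return paths unsorted and possibly with duplicates (the commented-out lines would dedupe and sort). A minimal usage sketch; the extension list below is an assumption standing in for the module-level `IMAGE_EXTENSIONS`:

```
import glob
import os
from pathlib import Path

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]  # stand-in for the real module-level list

# fnmatch-style lookup in a single directory
paths = glob_images("train_data/10_sks", base="*")

# pathlib-based lookup, recursing into subfolders when requested
paths_recursive = glob_images_pathlib(Path("train_data"), recursive=True)
```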
# endregion
@@ -1497,5 +1535,30 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
# endregion

# region for preprocessing / 前処理用

class ImageLoadingDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths):
        self.images = image_paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]

        try:
            image = Image.open(img_path).convert("RGB")
            # convert to tensor temporarily so dataloader will accept it
            tensor_pil = transforms.functional.pil_to_tensor(image)
        except Exception as e:
            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
            return None

        return (tensor_pil, img_path)

# endregion
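Since `__getitem__` returns `None` for unreadable images, any consumer needs a collate function that drops those entries; a minimal sketch, assuming the dataset is fed to a standard `DataLoader` (the collate helper here is illustrative, not code from this commit):

```
from torch.utils.data import DataLoader

def collate_fn_remove_corrupted(batch):
    # drop None entries produced by images that failed to load
    return [item for item in batch if item is not None]

dataset = ImageLoadingDataset(image_paths)  # image_paths: list of image file paths
loader = DataLoader(dataset, batch_size=4, num_workers=2, collate_fn=collate_fn_remove_corrupted)

for batch in loader:
    for tensor_pil, img_path in batch:
        pass  # e.g. resize, encode with the VAE, write latents to an .npz next to the image
```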

View File

@@ -1,5 +1,5 @@
accelerate==0.15.0
transformers==4.26.0
ftfy
albumentations
opencv-python

View File

@@ -1,3 +1,6 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from typing import Optional, Union
import importlib
import argparse
import gc
@@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
# Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when a newer diffusers version (v0.12.1 or greater) is tested and implemented in sd-scripts
def get_scheduler_fix(
    name: Union[str, SchedulerType],
@@ -52,53 +52,53 @@ def get_scheduler_fix(
    num_cycles: int = 1,
    power: float = 1.0,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See the `POLYNOMIAL` scheduler.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
        )

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
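A usage sketch for the patched factory, matching the new `--lr_scheduler_num_cycles` and `--lr_scheduler_power` options; the optimizer, parameter list, and step counts below are placeholders:

```
import torch

optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)  # trainable_params is a placeholder

lr_scheduler = get_scheduler_fix(
    "cosine_with_restarts", optimizer,
    num_warmup_steps=100, num_training_steps=10000,
    num_cycles=3,  # --lr_scheduler_num_cycles
)

# or polynomial decay:
# lr_scheduler = get_scheduler_fix("polynomial", optimizer,
#                                  num_warmup_steps=100, num_training_steps=10000,
#                                  power=2.0)  # --lr_scheduler_power
```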
def train(args):
@@ -135,7 +135,7 @@ def train(args):
    train_util.debug_dataset(train_dataset)
    return

if len(train_dataset) == 0:
    print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
    return
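The clarified message targets a common mistake: `train_data_dir` must be the parent of the image folders, not an image folder itself. A hypothetical DreamBooth layout (folder names are examples; the `<repeats>_<identifier>` naming supplies `n_repeats`):

```
# train_data/            <- pass this directory as train_data_dir
#   10_sks/              <- "<repeats>_<identifier>", repeated 10x per epoch
#     image001.png
#     image002.jpg
#   5_style/
#     photo1.png
```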
# prepare the accelerator / acceleratorを準備する
@@ -335,6 +335,7 @@ def train(args):
"ss_keep_tokens": args.keep_tokens, "ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
"ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment # will not be updated after training "ss_training_comment": args.training_comment # will not be updated after training
} }
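These `ss_*` entries end up as JSON strings in the saved model's metadata; a sketch of reading the new tag-frequency field back from a `.safetensors` LoRA file (the file name is a placeholder):

```
import json
from safetensors import safe_open

with safe_open("my_lora.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}

tag_frequency = json.loads(metadata.get("ss_tag_frequency", "{}"))
for dataset_name, freq in tag_frequency.items():
    top10 = sorted(freq.items(), key=lambda kv: -kv[1])[:10]
    print(dataset_name, top10)  # ten most frequent tags per dataset
```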