From 1945fa186d76fb5545b1f6491b77512bf4953320 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 8 Jan 2023 18:50:52 +0900
Subject: [PATCH] Show error if caption isn't UTF-8, add bmp support

---
NOTE(review): this patch was restored from a copy whose line breaks and
indentation had been collapsed. Hunk boundaries follow the @@ line counts;
the 2-space indentation (and continuation-line alignment) of the hunk lines
is an assumption - confirm against library/train_util.py before applying,
or apply with `git apply --ignore-whitespace`.

 library/train_util.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 2eb16c00..98ad10ef 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1,5 +1,4 @@
 # common functions for training
-# TODO test no_token_padding option
 
 import argparse
 import json
@@ -42,6 +41,8 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
 
 # region dataset
 
+IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
+
 
 class ImageInfo():
   def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -476,7 +477,11 @@ class DreamBoothDataset(BaseDataset):
       for cap_path in cap_paths:
         if os.path.isfile(cap_path):
           with open(cap_path, "rt", encoding='utf-8') as f:
-            lines = f.readlines()
+            try:
+              lines = f.readlines()
+            except UnicodeDecodeError as e:
+              print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
+              raise e
             assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
             caption = lines[0].strip()
           break
@@ -495,8 +500,7 @@ class DreamBoothDataset(BaseDataset):
         return 0, [], []
 
       caption_by_folder = '_'.join(tokens[1:])
-      img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
-          glob.glob(os.path.join(dir, "*.webp"))
+      img_paths = glob_images(dir, "*")
       print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
 
       # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
@@ -581,8 +585,7 @@ class FineTuningDataset(BaseDataset):
         abs_path = image_key
       else:
         # わりといい加減だがいい方法が思いつかん
-        abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) +
-                    glob.glob(os.path.join(train_data_dir, f"{image_key}.webp")))
+        abs_path = glob_images(train_data_dir, image_key)
         assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
         abs_path = abs_path[0]
 
@@ -705,6 +708,12 @@ def debug_dataset(train_dataset):
     if k == 27 or example['images'] is None:
       break
 
+def glob_images(dir, base):
+  img_paths = []
+  for ext in IMAGE_EXTENSIONS:
+    img_paths.extend(glob.glob(os.path.join(dir, base + ext)))
+  return img_paths
+
 # endregion
 
 
@@ -1210,6 +1219,10 @@ def patch_accelerator_for_fp16_training(accelerator):
 
 
 def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
+  # with no_token_padding, the length is not max length, return result immediately
+  if input_ids.size()[-1] != tokenizer.model_max_length:
+    return text_encoder(input_ids)[0]
+
   b_size = input_ids.size()[0]
   input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
 