Show error if caption isn't UTF-8, add bmp support

This commit is contained in:
Kohya S
2023-01-08 18:50:52 +09:00
parent 82e585cf01
commit 1945fa186d

View File

@@ -1,5 +1,4 @@
# common functions for training
# TODO test no_token_padding option
import argparse
import json
@@ -42,6 +41,8 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
class ImageInfo():
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -476,7 +477,11 @@ class DreamBoothDataset(BaseDataset):
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding='utf-8') as f:
lines = f.readlines()
try:
lines = f.readlines()
except UnicodeDecodeError as e:
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
break
@@ -495,8 +500,7 @@ class DreamBoothDataset(BaseDataset):
return 0, [], []
caption_by_folder = '_'.join(tokens[1:])
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
glob.glob(os.path.join(dir, "*.webp"))
img_paths = glob_images(dir, "*")
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
@@ -581,8 +585,7 @@ class FineTuningDataset(BaseDataset):
abs_path = image_key
else:
# わりといい加減だがいい方法が思いつかん
abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) +
glob.glob(os.path.join(train_data_dir, f"{image_key}.webp")))
abs_path = glob_images(train_data_dir, image_key)
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
abs_path = abs_path[0]
@@ -705,6 +708,12 @@ def debug_dataset(train_dataset):
if k == 27 or example['images'] is None:
break
def glob_images(dir, base):
img_paths = []
for ext in IMAGE_EXTENSIONS:
img_paths.extend(glob.glob(os.path.join(dir, base + ext)))
return img_paths
# endregion
@@ -1210,6 +1219,10 @@ def patch_accelerator_for_fp16_training(accelerator):
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
# with no_token_padding, the length is not max length, return result immediately
if input_ids.size()[-1] != tokenizer.model_max_length:
return text_encoder(input_ids)[0]
b_size = input_ids.size()[0]
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77