mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Show error if caption isn't UTF-8, add bmp support
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
# common functions for training
|
||||
# TODO test no_token_padding option
|
||||
|
||||
import argparse
|
||||
import json
|
||||
@@ -42,6 +41,8 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
|
||||
|
||||
# region dataset
|
||||
|
||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
||||
|
||||
|
||||
class ImageInfo():
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
@@ -476,7 +477,11 @@ class DreamBoothDataset(BaseDataset):
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
try:
|
||||
lines = f.readlines()
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
@@ -495,8 +500,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
return 0, [], []
|
||||
|
||||
caption_by_folder = '_'.join(tokens[1:])
|
||||
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
|
||||
glob.glob(os.path.join(dir, "*.webp"))
|
||||
img_paths = glob_images(dir, "*")
|
||||
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
@@ -581,8 +585,7 @@ class FineTuningDataset(BaseDataset):
|
||||
abs_path = image_key
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) +
|
||||
glob.glob(os.path.join(train_data_dir, f"{image_key}.webp")))
|
||||
abs_path = glob_images(train_data_dir, image_key)
|
||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
|
||||
abs_path = abs_path[0]
|
||||
|
||||
@@ -705,6 +708,12 @@ def debug_dataset(train_dataset):
|
||||
if k == 27 or example['images'] is None:
|
||||
break
|
||||
|
||||
def glob_images(dir, base):
|
||||
img_paths = []
|
||||
for ext in IMAGE_EXTENSIONS:
|
||||
img_paths.extend(glob.glob(os.path.join(dir, base + ext)))
|
||||
return img_paths
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -1210,6 +1219,10 @@ def patch_accelerator_for_fp16_training(accelerator):
|
||||
|
||||
|
||||
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
|
||||
# with no_token_padding, the length is not max length, return result immediately
|
||||
if input_ids.size()[-1] != tokenizer.model_max_length:
|
||||
return text_encoder(input_ids)[0]
|
||||
|
||||
b_size = input_ids.size()[0]
|
||||
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
||||
|
||||
|
||||
Reference in New Issue
Block a user