support multiline captions ref #1155

This commit is contained in:
Kohya S
2024-03-23 18:51:37 +09:00
parent 594c7f7050
commit f4a4c11cd3
4 changed files with 94 additions and 14 deletions

View File

@@ -693,6 +693,10 @@ class BaseDataset(torch.utils.data.Dataset):
else:
# process wildcards
if subset.enable_wildcard:
# if caption is multiline, randomly choose one line
if "\n" in caption:
caption = random.choice(caption.split("\n"))
# wildcard is like '{aaa|bbb|ccc...}'
# escape the curly braces like {{ or }}
replacer1 = ""
@@ -711,6 +715,9 @@ class BaseDataset(torch.utils.data.Dataset):
# unescape the curly braces
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
else:
# if caption is multiline, use the first line
caption = caption.split("\n")[0]
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
fixed_tokens = []
@@ -1446,7 +1453,7 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension):
def read_caption(img_path, caption_extension, enable_wildcard):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
@@ -1465,7 +1472,10 @@ class DreamBoothDataset(BaseDataset):
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
return caption
@@ -1481,7 +1491,7 @@ class DreamBoothDataset(BaseDataset):
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
@@ -1657,10 +1667,24 @@ class FineTuningDataset(BaseDataset):
caption = img_md.get("caption")
tags = img_md.get("tags")
if caption is None:
caption = tags
elif tags is not None and len(tags) > 0:
caption = caption + ", " + tags
tags_list.append(tags)
caption = tags # could be multiline
tags = None
if subset.enable_wildcard:
# tags must be a single line
if tags is not None:
tags = tags.replace("\n", subset.caption_separator)
# add tags to each line of caption
if caption is not None and tags is not None:
caption = "\n".join(
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
)
else:
# use as is
if tags is not None and len(tags) > 0:
caption = caption + subset.caption_separator + tags
tags_list.append(tags)
if caption is None:
caption = ""