mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support multiline captions ref #1155
This commit is contained in:
@@ -693,6 +693,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
# process wildcards
|
||||
if subset.enable_wildcard:
|
||||
# if caption is multiline, random choice one line
|
||||
if "\n" in caption:
|
||||
caption = random.choice(caption.split("\n"))
|
||||
|
||||
# wildcard is like '{aaa|bbb|ccc...}'
|
||||
# escape the curly braces like {{ or }}
|
||||
replacer1 = "⦅"
|
||||
@@ -711,6 +715,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# unescape the curly braces
|
||||
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
|
||||
else:
|
||||
# if caption is multiline, use the first line
|
||||
caption = caption.split("\n")[0]
|
||||
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
fixed_tokens = []
|
||||
@@ -1446,7 +1453,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path, caption_extension):
|
||||
def read_caption(img_path, caption_extension, enable_wildcard):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
@@ -1465,7 +1472,10 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
caption = lines[0].strip()
|
||||
if enable_wildcard:
|
||||
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
|
||||
else:
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
return caption
|
||||
|
||||
@@ -1481,7 +1491,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
logger.warning(
|
||||
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
||||
@@ -1657,10 +1667,24 @@ class FineTuningDataset(BaseDataset):
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
if caption is None:
|
||||
caption = tags
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ", " + tags
|
||||
tags_list.append(tags)
|
||||
caption = tags # could be multiline
|
||||
tags = None
|
||||
|
||||
if subset.enable_wildcard:
|
||||
# tags must be single line
|
||||
if tags is not None:
|
||||
tags = tags.replace("\n", subset.caption_separator)
|
||||
|
||||
# add tags to each line of caption
|
||||
if caption is not None and tags is not None:
|
||||
caption = "\n".join(
|
||||
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
|
||||
)
|
||||
else:
|
||||
# use as is
|
||||
if tags is not None and len(tags) > 0:
|
||||
caption = caption + subset.caption_separator + tags
|
||||
tags_list.append(tags)
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
|
||||
Reference in New Issue
Block a user