Merge branch 'dev' into masked-loss

This commit is contained in:
Kohya S
2024-03-24 18:19:53 +09:00
6 changed files with 394 additions and 200 deletions

View File

@@ -693,6 +693,10 @@ class BaseDataset(torch.utils.data.Dataset):
else:
# process wildcards
if subset.enable_wildcard:
# if caption is multiline, random choice one line
if "\n" in caption:
caption = random.choice(caption.split("\n"))
# wildcard is like '{aaa|bbb|ccc...}'
# escape the curly braces like {{ or }}
replacer1 = ""
@@ -711,6 +715,9 @@ class BaseDataset(torch.utils.data.Dataset):
# unescape the curly braces
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
else:
# if caption is multiline, use the first line
caption = caption.split("\n")[0]
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
fixed_tokens = []
@@ -1446,7 +1453,7 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension):
def read_caption(img_path, caption_extension, enable_wildcard):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
@@ -1465,7 +1472,10 @@ class DreamBoothDataset(BaseDataset):
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
return caption
@@ -1481,7 +1491,7 @@ class DreamBoothDataset(BaseDataset):
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
@@ -1515,7 +1525,7 @@ class DreamBoothDataset(BaseDataset):
logger.info("prepare images.")
num_train_images = 0
num_reg_images = 0
reg_infos: List[ImageInfo] = []
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
if subset.num_repeats < 1:
logger.warning(
@@ -1544,7 +1554,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if subset.is_reg:
reg_infos.append(info)
reg_infos.append((info, subset))
else:
self.register_image(info, subset)
@@ -1565,7 +1575,7 @@ class DreamBoothDataset(BaseDataset):
n = 0
first_loop = True
while n < num_train_images:
for info in reg_infos:
for info, subset in reg_infos:
if first_loop:
self.register_image(info, subset)
n += info.num_repeats
@@ -1657,10 +1667,24 @@ class FineTuningDataset(BaseDataset):
caption = img_md.get("caption")
tags = img_md.get("tags")
if caption is None:
caption = tags
elif tags is not None and len(tags) > 0:
caption = caption + ", " + tags
tags_list.append(tags)
caption = tags # could be multiline
tags = None
if subset.enable_wildcard:
# tags must be single line
if tags is not None:
tags = tags.replace("\n", subset.caption_separator)
# add tags to each line of caption
if caption is not None and tags is not None:
caption = "\n".join(
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
)
else:
# use as is
if tags is not None and len(tags) > 0:
caption = caption + subset.caption_separator + tags
tags_list.append(tags)
if caption is None:
caption = ""
@@ -3339,6 +3363,18 @@ def verify_training_args(args: argparse.Namespace):
+ " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
)
if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0:
logger.warning(
"sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります"
)
args.sample_every_n_epochs = None
if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0:
logger.warning(
"sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります"
)
args.sample_every_n_steps = None
def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool