mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'dev' into deep-speed
This commit is contained in:
@@ -694,6 +694,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
# process wildcards
|
||||
if subset.enable_wildcard:
|
||||
# if caption is multiline, random choice one line
|
||||
if "\n" in caption:
|
||||
caption = random.choice(caption.split("\n"))
|
||||
|
||||
# wildcard is like '{aaa|bbb|ccc...}'
|
||||
# escape the curly braces like {{ or }}
|
||||
replacer1 = "⦅"
|
||||
@@ -712,6 +716,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# unescape the curly braces
|
||||
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
|
||||
else:
|
||||
# if caption is multiline, use the first line
|
||||
caption = caption.split("\n")[0]
|
||||
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
fixed_tokens = []
|
||||
@@ -1447,7 +1454,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path, caption_extension):
|
||||
def read_caption(img_path, caption_extension, enable_wildcard):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
@@ -1466,7 +1473,10 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
caption = lines[0].strip()
|
||||
if enable_wildcard:
|
||||
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
|
||||
else:
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
return caption
|
||||
|
||||
@@ -1482,7 +1492,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
logger.warning(
|
||||
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
||||
@@ -1516,7 +1526,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info("prepare images.")
|
||||
num_train_images = 0
|
||||
num_reg_images = 0
|
||||
reg_infos: List[ImageInfo] = []
|
||||
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
logger.warning(
|
||||
@@ -1545,7 +1555,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
if subset.is_reg:
|
||||
reg_infos.append(info)
|
||||
reg_infos.append((info, subset))
|
||||
else:
|
||||
self.register_image(info, subset)
|
||||
|
||||
@@ -1566,7 +1576,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
n = 0
|
||||
first_loop = True
|
||||
while n < num_train_images:
|
||||
for info in reg_infos:
|
||||
for info, subset in reg_infos:
|
||||
if first_loop:
|
||||
self.register_image(info, subset)
|
||||
n += info.num_repeats
|
||||
@@ -1658,10 +1668,24 @@ class FineTuningDataset(BaseDataset):
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
if caption is None:
|
||||
caption = tags
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ", " + tags
|
||||
tags_list.append(tags)
|
||||
caption = tags # could be multiline
|
||||
tags = None
|
||||
|
||||
if subset.enable_wildcard:
|
||||
# tags must be single line
|
||||
if tags is not None:
|
||||
tags = tags.replace("\n", subset.caption_separator)
|
||||
|
||||
# add tags to each line of caption
|
||||
if caption is not None and tags is not None:
|
||||
caption = "\n".join(
|
||||
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
|
||||
)
|
||||
else:
|
||||
# use as is
|
||||
if tags is not None and len(tags) > 0:
|
||||
caption = caption + subset.caption_separator + tags
|
||||
tags_list.append(tags)
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
@@ -3315,6 +3339,18 @@ def verify_training_args(args: argparse.Namespace):
|
||||
+ " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
|
||||
)
|
||||
|
||||
if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0:
|
||||
logger.warning(
|
||||
"sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります"
|
||||
)
|
||||
args.sample_every_n_epochs = None
|
||||
|
||||
if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0:
|
||||
logger.warning(
|
||||
"sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります"
|
||||
)
|
||||
args.sample_every_n_steps = None
|
||||
|
||||
|
||||
def add_dataset_arguments(
|
||||
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
||||
|
||||
Reference in New Issue
Block a user