mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge remote-tracking branch 'kohya-ss/dev' into val
This commit is contained in:
@@ -378,6 +378,8 @@ class BaseSubset:
|
||||
caption_separator: str,
|
||||
keep_tokens: int,
|
||||
keep_tokens_separator: str,
|
||||
secondary_separator: Optional[str],
|
||||
enable_wildcard: bool,
|
||||
color_aug: bool,
|
||||
flip_aug: bool,
|
||||
face_crop_aug_range: Optional[Tuple[float, float]],
|
||||
@@ -396,6 +398,8 @@ class BaseSubset:
|
||||
self.caption_separator = caption_separator
|
||||
self.keep_tokens = keep_tokens
|
||||
self.keep_tokens_separator = keep_tokens_separator
|
||||
self.secondary_separator = secondary_separator
|
||||
self.enable_wildcard = enable_wildcard
|
||||
self.color_aug = color_aug
|
||||
self.flip_aug = flip_aug
|
||||
self.face_crop_aug_range = face_crop_aug_range
|
||||
@@ -424,6 +428,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_separator: str,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -445,6 +451,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -480,6 +488,8 @@ class FineTuningSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -501,6 +511,8 @@ class FineTuningSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -533,6 +545,8 @@ class ControlNetSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -554,6 +568,8 @@ class ControlNetSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -689,15 +705,41 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
# process wildcards
|
||||
if subset.enable_wildcard:
|
||||
# wildcard is like '{aaa|bbb|ccc...}'
|
||||
# escape the curly braces like {{ or }}
|
||||
replacer1 = "⦅"
|
||||
replacer2 = "⦆"
|
||||
while replacer1 in caption or replacer2 in caption:
|
||||
replacer1 += "⦅"
|
||||
replacer2 += "⦆"
|
||||
|
||||
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
|
||||
|
||||
# replace the wildcard
|
||||
def replace_wildcard(match):
|
||||
return random.choice(match.group(1).split("|"))
|
||||
|
||||
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
|
||||
|
||||
# unescape the curly braces
|
||||
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
|
||||
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
fixed_tokens = []
|
||||
flex_tokens = []
|
||||
fixed_suffix_tokens = []
|
||||
if (
|
||||
hasattr(subset, "keep_tokens_separator")
|
||||
and subset.keep_tokens_separator
|
||||
and subset.keep_tokens_separator in caption
|
||||
):
|
||||
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
||||
if subset.keep_tokens_separator in flex_part:
|
||||
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
|
||||
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
|
||||
|
||||
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||
else:
|
||||
@@ -732,7 +774,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
flex_tokens = dropout_tags(flex_tokens)
|
||||
|
||||
caption = ", ".join(fixed_tokens + flex_tokens)
|
||||
caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
|
||||
|
||||
# process secondary separator
|
||||
if subset.secondary_separator:
|
||||
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
|
||||
|
||||
# textual inversion対応
|
||||
for str_from, str_to in self.replacements.items():
|
||||
@@ -1796,6 +1842,8 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.caption_separator,
|
||||
subset.keep_tokens,
|
||||
subset.keep_tokens_separator,
|
||||
subset.secondary_separator,
|
||||
subset.enable_wildcard,
|
||||
subset.color_aug,
|
||||
subset.flip_aug,
|
||||
subset.face_crop_aug_range,
|
||||
@@ -3306,6 +3354,18 @@ def add_dataset_arguments(
|
||||
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
|
||||
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--secondary_separator",
|
||||
type=str,
|
||||
default=None,
|
||||
help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption"
|
||||
+ " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_wildcard",
|
||||
action="store_true",
|
||||
help="enable wildcard for caption (e.g. '{image|picture|rendition}') / captionのワイルドカードを有効にする(例:'{image|picture|rendition}')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_prefix",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user