mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #975 from Linaqruf/dev
Add keep_tokens_separator as alternative for keep_tokens
This commit is contained in:
@@ -53,6 +53,7 @@ class BaseSubsetParams:
|
|||||||
shuffle_caption: bool = False
|
shuffle_caption: bool = False
|
||||||
caption_separator: str = ',',
|
caption_separator: str = ',',
|
||||||
keep_tokens: int = 0
|
keep_tokens: int = 0
|
||||||
|
keep_tokens_separator: str = None,
|
||||||
color_aug: bool = False
|
color_aug: bool = False
|
||||||
flip_aug: bool = False
|
flip_aug: bool = False
|
||||||
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
||||||
@@ -160,6 +161,7 @@ class ConfigSanitizer:
|
|||||||
"random_crop": bool,
|
"random_crop": bool,
|
||||||
"shuffle_caption": bool,
|
"shuffle_caption": bool,
|
||||||
"keep_tokens": int,
|
"keep_tokens": int,
|
||||||
|
"keep_tokens_separator": str,
|
||||||
"token_warmup_min": int,
|
"token_warmup_min": int,
|
||||||
"token_warmup_step": Any(float,int),
|
"token_warmup_step": Any(float,int),
|
||||||
"caption_prefix": str,
|
"caption_prefix": str,
|
||||||
@@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
|||||||
num_repeats: {subset.num_repeats}
|
num_repeats: {subset.num_repeats}
|
||||||
shuffle_caption: {subset.shuffle_caption}
|
shuffle_caption: {subset.shuffle_caption}
|
||||||
keep_tokens: {subset.keep_tokens}
|
keep_tokens: {subset.keep_tokens}
|
||||||
|
keep_tokens_separator: {subset.keep_tokens_separator}
|
||||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||||
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
||||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||||
|
|||||||
@@ -351,6 +351,7 @@ class BaseSubset:
|
|||||||
shuffle_caption: bool,
|
shuffle_caption: bool,
|
||||||
caption_separator: str,
|
caption_separator: str,
|
||||||
keep_tokens: int,
|
keep_tokens: int,
|
||||||
|
keep_tokens_separator: str,
|
||||||
color_aug: bool,
|
color_aug: bool,
|
||||||
flip_aug: bool,
|
flip_aug: bool,
|
||||||
face_crop_aug_range: Optional[Tuple[float, float]],
|
face_crop_aug_range: Optional[Tuple[float, float]],
|
||||||
@@ -368,6 +369,7 @@ class BaseSubset:
|
|||||||
self.shuffle_caption = shuffle_caption
|
self.shuffle_caption = shuffle_caption
|
||||||
self.caption_separator = caption_separator
|
self.caption_separator = caption_separator
|
||||||
self.keep_tokens = keep_tokens
|
self.keep_tokens = keep_tokens
|
||||||
|
self.keep_tokens_separator = keep_tokens_separator
|
||||||
self.color_aug = color_aug
|
self.color_aug = color_aug
|
||||||
self.flip_aug = flip_aug
|
self.flip_aug = flip_aug
|
||||||
self.face_crop_aug_range = face_crop_aug_range
|
self.face_crop_aug_range = face_crop_aug_range
|
||||||
@@ -395,6 +397,7 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator: str,
|
caption_separator: str,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -415,6 +418,7 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -449,6 +453,7 @@ class FineTuningSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -469,6 +474,7 @@ class FineTuningSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -500,6 +506,7 @@ class ControlNetSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -520,6 +527,7 @@ class ControlNetSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -654,15 +662,29 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
caption = ""
|
caption = ""
|
||||||
else:
|
else:
|
||||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||||
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
|
fixed_tokens = []
|
||||||
|
flex_tokens = []
|
||||||
|
if hasattr(subset, 'keep_tokens_separator') and subset.keep_tokens_separator in caption:
|
||||||
|
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
||||||
|
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||||
|
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||||
|
else:
|
||||||
|
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
|
||||||
|
flex_tokens = tokens[:]
|
||||||
|
if subset.keep_tokens > 0:
|
||||||
|
fixed_tokens = flex_tokens[:subset.keep_tokens]
|
||||||
|
flex_tokens = tokens[subset.keep_tokens:]
|
||||||
|
|
||||||
|
|
||||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||||
tokens_len = (
|
tokens_len = (
|
||||||
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
||||||
+ subset.token_warmup_min
|
+ subset.token_warmup_min
|
||||||
)
|
)
|
||||||
tokens = tokens[:tokens_len]
|
flex_tokens = flex_tokens[:tokens_len]
|
||||||
|
|
||||||
|
|
||||||
def dropout_tags(tokens):
|
def dropout_tags(tokens):
|
||||||
if subset.caption_tag_dropout_rate <= 0:
|
if subset.caption_tag_dropout_rate <= 0:
|
||||||
@@ -673,12 +695,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
l.append(token)
|
l.append(token)
|
||||||
return l
|
return l
|
||||||
|
|
||||||
fixed_tokens = []
|
|
||||||
flex_tokens = tokens[:]
|
|
||||||
if subset.keep_tokens > 0:
|
|
||||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
|
||||||
flex_tokens = tokens[subset.keep_tokens :]
|
|
||||||
|
|
||||||
if subset.shuffle_caption:
|
if subset.shuffle_caption:
|
||||||
random.shuffle(flex_tokens)
|
random.shuffle(flex_tokens)
|
||||||
|
|
||||||
@@ -1724,6 +1740,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
subset.shuffle_caption,
|
subset.shuffle_caption,
|
||||||
subset.caption_separator,
|
subset.caption_separator,
|
||||||
subset.keep_tokens,
|
subset.keep_tokens,
|
||||||
|
subset.keep_tokens_separator,
|
||||||
subset.color_aug,
|
subset.color_aug,
|
||||||
subset.flip_aug,
|
subset.flip_aug,
|
||||||
subset.face_crop_aug_range,
|
subset.face_crop_aug_range,
|
||||||
@@ -3131,6 +3148,12 @@ def add_dataset_arguments(
|
|||||||
default=0,
|
default=0,
|
||||||
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
|
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--keep_tokens_separator",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--caption_prefix",
|
"--caption_prefix",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user