added keep_tokens_separator to dynamically keep token for being shuffled

This commit is contained in:
Furqanil Taqwa
2023-11-28 17:23:55 +07:00
committed by GitHub
parent 4a913ce61e
commit 1624c239c2

View File

@@ -351,6 +351,7 @@ class BaseSubset:
shuffle_caption: bool, shuffle_caption: bool,
caption_separator: str, caption_separator: str,
keep_tokens: int, keep_tokens: int,
keep_tokens_separator: str,
color_aug: bool, color_aug: bool,
flip_aug: bool, flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]], face_crop_aug_range: Optional[Tuple[float, float]],
@@ -368,6 +369,7 @@ class BaseSubset:
self.shuffle_caption = shuffle_caption self.shuffle_caption = shuffle_caption
self.caption_separator = caption_separator self.caption_separator = caption_separator
self.keep_tokens = keep_tokens self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.color_aug = color_aug self.color_aug = color_aug
self.flip_aug = flip_aug self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range self.face_crop_aug_range = face_crop_aug_range
@@ -395,6 +397,7 @@ class DreamBoothSubset(BaseSubset):
shuffle_caption, shuffle_caption,
caption_separator: str, caption_separator: str,
keep_tokens, keep_tokens,
keep_tokens_separator,
color_aug, color_aug,
flip_aug, flip_aug,
face_crop_aug_range, face_crop_aug_range,
@@ -415,6 +418,7 @@ class DreamBoothSubset(BaseSubset):
shuffle_caption, shuffle_caption,
caption_separator, caption_separator,
keep_tokens, keep_tokens,
keep_tokens_separator,
color_aug, color_aug,
flip_aug, flip_aug,
face_crop_aug_range, face_crop_aug_range,
@@ -449,6 +453,7 @@ class FineTuningSubset(BaseSubset):
shuffle_caption, shuffle_caption,
caption_separator, caption_separator,
keep_tokens, keep_tokens,
keep_tokens_separator,
color_aug, color_aug,
flip_aug, flip_aug,
face_crop_aug_range, face_crop_aug_range,
@@ -469,6 +474,7 @@ class FineTuningSubset(BaseSubset):
shuffle_caption, shuffle_caption,
caption_separator, caption_separator,
keep_tokens, keep_tokens,
keep_tokens_separator,
color_aug, color_aug,
flip_aug, flip_aug,
face_crop_aug_range, face_crop_aug_range,
@@ -500,6 +506,7 @@ class ControlNetSubset(BaseSubset):
shuffle_caption, shuffle_caption,
caption_separator, caption_separator,
keep_tokens, keep_tokens,
keep_tokens_separator,
color_aug, color_aug,
flip_aug, flip_aug,
face_crop_aug_range, face_crop_aug_range,
@@ -520,6 +527,7 @@ class ControlNetSubset(BaseSubset):
shuffle_caption, shuffle_caption,
caption_separator, caption_separator,
keep_tokens, keep_tokens,
keep_tokens_separator,
color_aug, color_aug,
flip_aug, flip_aug,
face_crop_aug_range, face_crop_aug_range,
@@ -654,15 +662,29 @@ class BaseDataset(torch.utils.data.Dataset):
caption = "" caption = ""
else: else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] fixed_tokens = []
flex_tokens = []
if hasattr(subset, 'keep_tokens_separator') and subset.keep_tokens_separator in caption:
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[:subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens:]
if subset.token_warmup_step < 1: # 初回に上書きする if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step: if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = ( tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min + subset.token_warmup_min
) )
tokens = tokens[:tokens_len] flex_tokens = flex_tokens[:tokens_len]
def dropout_tags(tokens): def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0: if subset.caption_tag_dropout_rate <= 0:
@@ -673,12 +695,6 @@ class BaseDataset(torch.utils.data.Dataset):
l.append(token) l.append(token)
return l return l
fixed_tokens = []
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption: if subset.shuffle_caption:
random.shuffle(flex_tokens) random.shuffle(flex_tokens)
@@ -697,6 +713,7 @@ class BaseDataset(torch.utils.data.Dataset):
else: else:
caption = caption.replace(str_from, str_to) caption = caption.replace(str_from, str_to)
print(caption)
return caption return caption
def get_input_ids(self, caption, tokenizer=None): def get_input_ids(self, caption, tokenizer=None):
@@ -1723,6 +1740,7 @@ class ControlNetDataset(BaseDataset):
subset.num_repeats, subset.num_repeats,
subset.shuffle_caption, subset.shuffle_caption,
subset.keep_tokens, subset.keep_tokens,
subset.keep_tokens_separator,
subset.color_aug, subset.color_aug,
subset.flip_aug, subset.flip_aug,
subset.face_crop_aug_range, subset.face_crop_aug_range,
@@ -3133,6 +3151,12 @@ def add_dataset_arguments(
default=0, default=0,
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残すトークンはカンマ区切りの各部分を意味する", help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残すトークンはカンマ区切りの各部分を意味する",
) )
parser.add_argument(
"--keep_tokens_separator",
type=str,
default="",
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens.",
)
parser.add_argument( parser.add_argument(
"--caption_prefix", "--caption_prefix",
type=str, type=str,