From 4a913ce61edb7bb201c175ff4f9d641205e3eed2 Mon Sep 17 00:00:00 2001 From: Furqanil Taqwa <50163983+Linaqruf@users.noreply.github.com> Date: Tue, 28 Nov 2023 17:22:35 +0700 Subject: [PATCH 1/3] initialize keep_tokens_separator to dataset config --- library/config_util.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/library/config_util.py b/library/config_util.py index ab90fb63..47868f3b 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -53,6 +53,7 @@ class BaseSubsetParams: shuffle_caption: bool = False caption_separator: str = ',', keep_tokens: int = 0 + keep_tokens_separator: Optional[str] = None color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None @@ -160,6 +161,7 @@ class ConfigSanitizer: "random_crop": bool, "shuffle_caption": bool, "keep_tokens": int, + "keep_tokens_separator": str, "token_warmup_min": int, "token_warmup_step": Any(float,int), "caption_prefix": str, @@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu num_repeats: {subset.num_repeats} shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} From 1624c239c27362a3fbc03fa8c179ecf9ec17e24e Mon Sep 17 00:00:00 2001 From: Furqanil Taqwa <50163983+Linaqruf@users.noreply.github.com> Date: Tue, 28 Nov 2023 17:23:55 +0700 Subject: [PATCH 2/3] added keep_tokens_separator to dynamically keep tokens from being shuffled --- library/train_util.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 9fb616ed..5e152532 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -351,6 +351,7 @@ class BaseSubset: 
shuffle_caption: bool, caption_separator: str, keep_tokens: int, + keep_tokens_separator: str, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], @@ -368,6 +369,7 @@ class BaseSubset: self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator self.keep_tokens = keep_tokens + self.keep_tokens_separator = keep_tokens_separator self.color_aug = color_aug self.flip_aug = flip_aug self.face_crop_aug_range = face_crop_aug_range @@ -395,6 +397,7 @@ class DreamBoothSubset(BaseSubset): shuffle_caption, caption_separator: str, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -415,6 +418,7 @@ class DreamBoothSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -449,6 +453,7 @@ class FineTuningSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -469,6 +474,7 @@ class FineTuningSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -500,6 +506,7 @@ class ControlNetSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -520,6 +527,7 @@ class ControlNetSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -654,15 +662,29 @@ class BaseDataset(torch.utils.data.Dataset): caption = "" else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] + fixed_tokens = [] + flex_tokens = [] + if isinstance(getattr(subset, 'keep_tokens_separator', None), str) and subset.keep_tokens_separator and subset.keep_tokens_separator in caption: + fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) + 
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] + flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] + else: + tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] + flex_tokens = tokens[:] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[:subset.keep_tokens] + flex_tokens = tokens[subset.keep_tokens:] + + if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: tokens_len = ( - math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + subset.token_warmup_min ) - tokens = tokens[:tokens_len] + flex_tokens = flex_tokens[:tokens_len] + def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: @@ -673,12 +695,6 @@ class BaseDataset(torch.utils.data.Dataset): l.append(token) return l - fixed_tokens = [] - flex_tokens = tokens[:] - if subset.keep_tokens > 0: - fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = tokens[subset.keep_tokens :] - if subset.shuffle_caption: random.shuffle(flex_tokens) @@ -697,6 +713,7 @@ class BaseDataset(torch.utils.data.Dataset): else: caption = caption.replace(str_from, str_to) + print(caption) return caption def get_input_ids(self, caption, tokenizer=None): @@ -1723,6 +1740,7 @@ class ControlNetDataset(BaseDataset): subset.num_repeats, subset.shuffle_caption, subset.keep_tokens, + subset.keep_tokens_separator, subset.color_aug, subset.flip_aug, subset.face_crop_aug_range, @@ -3133,6 +3151,12 @@ def add_dataset_arguments( default=0, help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", ) 
+ parser.add_argument( + "--keep_tokens_separator", + type=str, + default="", + help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens.", + ) parser.add_argument( "--caption_prefix", type=str, From 1bdd83a85f9a381cf460b9a47a049cf68ceb67f0 Mon Sep 17 00:00:00 2001 From: Furqanil Taqwa <50163983+Linaqruf@users.noreply.github.com> Date: Tue, 28 Nov 2023 17:26:27 +0700 Subject: [PATCH 3/3] remove unnecessary debug print --- library/train_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 5e152532..8ea7a438 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -713,7 +713,6 @@ class BaseDataset(torch.utils.data.Dataset): else: caption = caption.replace(str_from, str_to) - print(caption) return caption def get_input_ids(self, caption, tokenizer=None):