From 1b89b2a10e1f623efd3945d422dcd0640ac4f0fd Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Fri, 24 Mar 2023 13:44:30 +0900
Subject: [PATCH] Truncate tags before shuffling
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 library/train_util.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 83e9372b..d1df9c58 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -477,6 +477,13 @@ class BaseDataset(torch.utils.data.Dataset):
         else:
             if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
 
+                tokens = [t.strip() for t in caption.strip().split(",")]
+                if subset.token_warmup_step < 1:
+                    subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
+                if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
+                    tokens_len = math.floor(self.current_step * ((len(tokens) - subset.token_warmup_min) / subset.token_warmup_step)) + subset.token_warmup_min
+                    tokens = tokens[:tokens_len]
+
                 def dropout_tags(tokens):
                     if subset.caption_tag_dropout_rate <= 0:
                         return tokens
@@ -487,24 +494,17 @@ class BaseDataset(torch.utils.data.Dataset):
                     return l
 
                 fixed_tokens = []
-                flex_tokens = [t.strip() for t in caption.strip().split(",")]
+                flex_tokens = tokens[:]
                 if subset.keep_tokens > 0:
                     fixed_tokens = flex_tokens[: subset.keep_tokens]
-                    flex_tokens = flex_tokens[subset.keep_tokens :]
+                    flex_tokens = tokens[subset.keep_tokens :]
 
                 if subset.shuffle_caption:
                     random.shuffle(flex_tokens)
 
                 flex_tokens = dropout_tags(flex_tokens)
 
-                tokens = fixed_tokens + flex_tokens
-                if subset.token_warmup_step < 1:
-                    subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
-                if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
-                    tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
-                    tokens = tokens[:tokens_len]
-
-                caption = ", ".join(tokens)
+                caption = ", ".join(fixed_tokens + flex_tokens)
 
             # textual inversion対応
             for str_from, str_to in self.replacements.items():
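
Note: the change moves the token-warmup truncation ahead of the keep_tokens split, the shuffle, and the tag dropout, so while warmup is active the caption keeps its leading tags rather than a random post-shuffle subset. Below is a minimal standalone sketch of the resulting order of operations; it is not the library code, and names such as warmup_step, warmup_min, keep_tokens, and tag_dropout_rate are simplified stand-ins for the corresponding subset.* fields in library/train_util.py:

    import math
    import random

    def process_caption(caption, current_step, max_train_steps,
                        warmup_step=0.5, warmup_min=1, keep_tokens=1,
                        shuffle_caption=True, tag_dropout_rate=0.1):
        # Split the comma-separated caption into stripped tags.
        tokens = [t.strip() for t in caption.strip().split(",")]

        # Warmup truncation now runs FIRST (the point of this patch): while
        # current_step < warmup_step, trailing tags of the ORIGINAL caption
        # are cut off, growing linearly from warmup_min to the full length.
        if warmup_step < 1:  # a fraction is treated as a ratio of max_train_steps
            warmup_step = math.floor(warmup_step * max_train_steps)
        if warmup_step and current_step < warmup_step:
            tokens_len = math.floor(
                current_step * ((len(tokens) - warmup_min) / warmup_step)
            ) + warmup_min
            tokens = tokens[:tokens_len]

        # The leading keep_tokens tags stay fixed; the rest may be shuffled
        # and randomly dropped.
        fixed_tokens, flex_tokens = tokens[:keep_tokens], tokens[keep_tokens:]
        if shuffle_caption:
            random.shuffle(flex_tokens)
        flex_tokens = [t for t in flex_tokens if random.random() >= tag_dropout_rate]

        return ", ".join(fixed_tokens + flex_tokens)

    print(process_caption(
        "1girl, solo, smile, long hair, outdoors, blue sky, cloud, day, scenery",
        current_step=125, max_train_steps=1000))

For example, with max_train_steps=1000 and warmup_step=0.5 the warmup lasts 500 steps, so at step 125 the 9-tag caption above is cut to floor(125 * (9 - 1) / 500) + 1 = 3 tags ("1girl, solo, smile") before any shuffling or dropout takes place.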