シャッフル前にタグを切り詰めるように変更

This commit is contained in:
u-haru
2023-03-24 13:44:30 +09:00
parent 143c26e552
commit 1b89b2a10e

View File

@@ -477,6 +477,13 @@ class BaseDataset(torch.utils.data.Dataset):
else: else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
if subset.token_warmup_step < 1:
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
tokens = tokens[:tokens_len]
def dropout_tags(tokens): def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0: if subset.caption_tag_dropout_rate <= 0:
return tokens return tokens
@@ -487,24 +494,17 @@ class BaseDataset(torch.utils.data.Dataset):
return l return l
fixed_tokens = [] fixed_tokens = []
flex_tokens = [t.strip() for t in caption.strip().split(",")] flex_tokens = tokens[:]
if subset.keep_tokens > 0: if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens] fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = flex_tokens[subset.keep_tokens :] flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption: if subset.shuffle_caption:
random.shuffle(flex_tokens) random.shuffle(flex_tokens)
flex_tokens = dropout_tags(flex_tokens) flex_tokens = dropout_tags(flex_tokens)
tokens = fixed_tokens + flex_tokens
if subset.token_warmup_step < 1: caption = ", ".join(fixed_tokens + flex_tokens)
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
tokens = tokens[:tokens_len]
caption = ", ".join(tokens)
# textual inversion対応 # textual inversion対応
for str_from, str_to in self.replacements.items(): for str_from, str_to in self.replacements.items():