mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
シャッフル前にタグを切り詰めるように変更
This commit is contained in:
@@ -477,6 +477,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if subset.token_warmup_step < 1:
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
|
||||
tokens = tokens[:tokens_len]
|
||||
|
||||
def dropout_tags(tokens):
|
||||
if subset.caption_tag_dropout_rate <= 0:
|
||||
return tokens
|
||||
@@ -487,24 +494,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return l
|
||||
|
||||
fixed_tokens = []
|
||||
flex_tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
flex_tokens = tokens[:]
|
||||
if subset.keep_tokens > 0:
|
||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||
flex_tokens = flex_tokens[subset.keep_tokens :]
|
||||
flex_tokens = tokens[subset.keep_tokens :]
|
||||
|
||||
if subset.shuffle_caption:
|
||||
random.shuffle(flex_tokens)
|
||||
|
||||
flex_tokens = dropout_tags(flex_tokens)
|
||||
tokens = fixed_tokens + flex_tokens
|
||||
|
||||
if subset.token_warmup_step < 1:
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
|
||||
tokens = tokens[:tokens_len]
|
||||
|
||||
caption = ", ".join(tokens)
|
||||
caption = ", ".join(fixed_tokens + flex_tokens)
|
||||
|
||||
# textual inversion対応
|
||||
for str_from, str_to in self.replacements.items():
|
||||
|
||||
Reference in New Issue
Block a user