Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Changed to truncate tags before shuffling
@@ -477,6 +477,13 @@ class BaseDataset(torch.utils.data.Dataset):
         else:
             if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
 
+                tokens = [t.strip() for t in caption.strip().split(",")]
+                if subset.token_warmup_step < 1:
+                    subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
+                if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
+                    tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
+                    tokens = tokens[:tokens_len]
+
                 def dropout_tags(tokens):
                     if subset.caption_tag_dropout_rate <= 0:
                         return tokens
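For reference, a standalone sketch of the warmup truncation the added lines implement. The helper name and its plain arguments are hypothetical stand-ins for the real subset fields (token_warmup_step, token_warmup_min) and dataset state (current_step, max_train_steps); it is an illustration, not the library's API.

import math

def warmup_truncate(tokens, current_step, warmup_step, warmup_min, max_train_steps):
    # A warmup_step below 1 is treated as a fraction of total training steps
    # and resolved to an absolute step count (the diff writes this back to
    # subset.token_warmup_step; here it stays local).
    if warmup_step < 1:
        warmup_step = math.floor(warmup_step * max_train_steps)
    # Inside the warmup window, keep only a prefix of the tags, growing
    # linearly from warmup_min tags up to the full tag count.
    if warmup_step and current_step < warmup_step:
        tokens_len = math.floor(current_step * (len(tokens) - warmup_min) / warmup_step) + warmup_min
        tokens = tokens[:tokens_len]
    return tokens

tags = ["1girl", "solo", "smile", "outdoors"]
print(warmup_truncate(tags, current_step=50, warmup_step=100, warmup_min=1, max_train_steps=1000))
# halfway through warmup: floor(50 * 3 / 100) + 1 = 2 -> ['1girl', 'solo']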
@@ -487,24 +494,17 @@ class BaseDataset(torch.utils.data.Dataset):
                     return l
 
                 fixed_tokens = []
-                flex_tokens = [t.strip() for t in caption.strip().split(",")]
+                flex_tokens = tokens[:]
                 if subset.keep_tokens > 0:
                     fixed_tokens = flex_tokens[: subset.keep_tokens]
-                    flex_tokens = flex_tokens[subset.keep_tokens :]
+                    flex_tokens = tokens[subset.keep_tokens :]
 
                 if subset.shuffle_caption:
                     random.shuffle(flex_tokens)
 
                 flex_tokens = dropout_tags(flex_tokens)
-                tokens = fixed_tokens + flex_tokens
 
-                if subset.token_warmup_step < 1:
-                    subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
-                if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
-                    tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
-                    tokens = tokens[:tokens_len]
-
-                caption = ", ".join(tokens)
+                caption = ", ".join(fixed_tokens + flex_tokens)
 
         # textual inversion対応
         for str_from, str_to in self.replacements.items():
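Taken together, the two hunks reorder caption processing to: truncate during warmup first, then split off keep_tokens, shuffle, apply tag dropout, and rejoin. A minimal sketch of that order, assuming flat arguments in place of the real subset object (process_caption_sketch is an illustrative name, not the method's actual signature):

import math
import random

def process_caption_sketch(caption, keep_tokens=0, shuffle=True,
                           tag_dropout_rate=0.0,
                           current_step=0, warmup_step=0, warmup_min=1):
    tokens = [t.strip() for t in caption.strip().split(",")]

    # New in this commit: the warmup truncation runs first, on the caption's
    # original tag order, before any shuffling.
    if warmup_step and current_step < warmup_step:
        tokens_len = math.floor(current_step * (len(tokens) - warmup_min) / warmup_step) + warmup_min
        tokens = tokens[:tokens_len]

    # The first keep_tokens tags stay fixed; only the rest are shuffled/dropped.
    fixed_tokens = tokens[:keep_tokens]
    flex_tokens = tokens[keep_tokens:]

    if shuffle:
        random.shuffle(flex_tokens)
    if tag_dropout_rate > 0:
        flex_tokens = [t for t in flex_tokens if random.random() >= tag_dropout_rate]

    return ", ".join(fixed_tokens + flex_tokens)

print(process_caption_sketch("1girl, solo, smile, outdoors, day", keep_tokens=1))

Because the truncation now precedes the shuffle, the tags removed during warmup are always the trailing tags of the original caption rather than a random subset of it, which is the behavior the commit message describes.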