mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Fix crash TI train close #172, tag drop wo shuffle
This commit is contained in:
@@ -226,6 +226,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
self.dropout_rate: float = 0
|
||||
self.dropout_every_n_epochs: int = None
|
||||
self.tag_dropout_rate: float = 0
|
||||
|
||||
# augmentation
|
||||
flip_p = 0.5 if flip_aug else 0.0
|
||||
@@ -284,7 +285,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
if self.shuffle_caption:
|
||||
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
||||
def dropout_tags(tokens):
|
||||
if self.tag_dropout_rate <= 0:
|
||||
return tokens
|
||||
@@ -296,13 +297,18 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if self.shuffle_keep_tokens is None:
|
||||
random.shuffle(tokens)
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
else:
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
||||
tokens = tokens[self.shuffle_keep_tokens:]
|
||||
random.shuffle(tokens)
|
||||
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
|
||||
tokens = keep_tokens + tokens
|
||||
|
||||
Reference in New Issue
Block a user