Merge pull request #159 from forestsource/main

Add Conditional Dropout options
This commit is contained in:
Kohya S
2023-02-07 21:50:26 +09:00
committed by GitHub
4 changed files with 43 additions and 5 deletions

View File

@@ -223,6 +223,10 @@ class BaseDataset(torch.utils.data.Dataset):
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
self.epoch_current:int = int(0)
self.dropout_rate:float = 0
self.dropout_every_n_epochs:int = 0
# augmentation
flip_p = 0.5 if flip_aug else 0.0
if color_aug:
@@ -598,7 +602,17 @@ class BaseDataset(torch.utils.data.Dataset):
images.append(image)
latents_list.append(latents)
caption = self.process_caption(image_info.caption)
# dropoutの決定
is_drop_out = False
if self.dropout_rate > 0 and self.dropout_rate < random.random() :
is_drop_out = True
if self.dropout_every_n_epochs > 0 and self.epoch_current % self.dropout_every_n_epochs == 0 :
is_drop_out = True
if is_drop_out:
caption = ""
else:
caption = self.process_caption(image_info.caption)
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption))
@@ -1407,6 +1421,10 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
parser.add_argument("--dropout_rate", type=float, default=0,
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--dropout_every_n_epochs", type=int, default=0,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
if support_dreambooth:
# DreamBooth dataset