diff --git a/fine_tune.py b/fine_tune.py index 6a95886c..17b89852 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -171,6 +171,10 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # Pass the caption dropout settings to the training dataset + train_dataset.dropout_rate = args.dropout_rate + train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs + + # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) @@ -226,6 +230,9 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + + train_dataset.epoch_current = epoch + 1 + + for m in training_models: m.train() diff --git a/library/train_util.py b/library/train_util.py index 6f809deb..10fc4416 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -223,6 +223,10 @@ class BaseDataset(torch.utils.data.Dataset): self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + self.epoch_current:int = int(0) + self.dropout_rate:float = 0 + self.dropout_every_n_epochs:int = 0 + # augmentation flip_p = 0.5 if flip_aug else 0.0 if color_aug: @@ -598,7 +602,17 @@ class BaseDataset(torch.utils.data.Dataset): images.append(image) latents_list.append(latents) - caption = self.process_caption(image_info.caption) + # Decide caption dropout: drop at random with probability dropout_rate, or drop every caption on each N-th epoch + is_drop_out = False + if self.dropout_rate > 0 and random.random() < self.dropout_rate: + is_drop_out = True + if self.dropout_every_n_epochs > 0 and self.epoch_current % self.dropout_every_n_epochs == 0: + is_drop_out = True + + if is_drop_out: + caption = "" + else: + caption = self.process_caption(image_info.caption) captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future 
input_ids_list.append(self.get_input_ids(caption)) @@ -1407,6 +1421,10 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します") parser.add_argument("--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") + parser.add_argument("--dropout_rate", type=float, default=0, + help="Rate of caption dropout (0.0~1.0) / captionをdropoutする割合") + parser.add_argument("--dropout_every_n_epochs", type=int, default=0, + help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") if support_dreambooth: # DreamBooth dataset diff --git a/train_db.py b/train_db.py index d1bbc07f..96a4dde6 100644 --- a/train_db.py +++ b/train_db.py @@ -136,6 +136,10 @@ def train(args): train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + # Pass the caption dropout settings to the training dataset + train_dataset.dropout_rate = args.dropout_rate + train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * len(train_dataloader) @@ -204,6 +208,8 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.epoch_current = epoch + 1 + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() # train==True is required to enable gradient_checkpointing diff --git a/train_network.py b/train_network.py index 3e8f4e7d..82ebeaf1 100644 --- a/train_network.py +++ b/train_network.py @@ -120,16 +120,16 @@ def train(args): print("Use DreamBooth method.") train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - 
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, - args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.bucket_reso_steps, args.bucket_no_upscale, + args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) else: print("Train with captions.") train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.bucket_reso_steps, args.bucket_no_upscale, + args.bucket_reso_steps, args.bucket_no_upscale, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.dataset_repeats, args.debug_dataset) train_dataset.make_buckets() @@ -219,6 +219,10 @@ def train(args): train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + # Pass the caption dropout settings to the training dataset + train_dataset.dropout_rate = args.dropout_rate + train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * len(train_dataloader) @@ -376,6 +380,9 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + + train_dataset.epoch_current = epoch + 1 + metadata["ss_epoch"] = str(epoch+1) network.on_epoch_start(text_encoder, unet)