conditional caption dropout (in progress)

This commit is contained in:
Kohya S
2023-02-07 22:28:56 +09:00
parent f9478f0d47
commit e42b2f7aa9
5 changed files with 40 additions and 28 deletions

View File

@@ -38,8 +38,13 @@ def train(args):
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
if args.no_token_padding:
train_dataset.disable_token_padding()
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
train_dataset.make_buckets()
if args.debug_dataset:
@@ -136,10 +141,6 @@ def train(args):
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習データのdropout率を設定する
train_dataset.dropout_rate = args.dropout_rate
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
@@ -333,7 +334,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser)