mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
conditional caption dropout (in progress)
This commit is contained in:
@@ -132,6 +132,10 @@ def train(args):
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||
args.dataset_repeats, args.debug_dataset)
|
||||
|
||||
# 学習データのdropout率を設定する
|
||||
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
|
||||
|
||||
train_dataset.make_buckets()
|
||||
|
||||
if args.debug_dataset:
|
||||
@@ -219,10 +223,6 @@ def train(args):
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||
|
||||
# 学習データのdropout率を設定する
|
||||
train_dataset.dropout_rate = args.dropout_rate
|
||||
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
@@ -516,7 +516,7 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
|
||||
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||
|
||||
Reference in New Issue
Block a user