Add dropout options

This commit is contained in:
forestsource
2023-02-07 00:01:30 +09:00
parent d591891048
commit 7db98baa86
4 changed files with 43 additions and 5 deletions

View File

@@ -171,6 +171,10 @@ def train(args):
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# 学習データのdropout率を設定する (set the dropout rate for the training data)
train_dataset.dropout_rate = args.dropout_rate
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
@@ -226,6 +230,9 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.epoch_current = epoch + 1
for m in training_models:
m.train()