fix duplicated sample gen for every epoch ref #907

Kohya S
2023-12-07 22:13:38 +09:00
parent db84530074
commit 912dca8f65
6 changed files with 44 additions and 53 deletions


@@ -272,13 +272,14 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    # For --sample_at_first
+    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
-        train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
         # Train the Text Encoder up to the specified number of steps (state at the start of each epoch)
         unet.train()
         # train==True is required to enable gradient_checkpointing
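In short, the one-off generation for --sample_at_first now happens once before the epoch loop (with epoch index 0), and the call at the top of every epoch is removed; regular sampling is still driven by the sample_images calls that remain inside the training loop further down in the file, so each epoch is no longer sampled twice. Below is a rough sketch of the resulting flow, assuming the surrounding structure of this training script (the "dreambooth" tracker name suggests train_db.py); the in-loop calls shown as comments are not part of this hunk.

# Rough sketch, not part of the diff: where sample generation happens after this change.
# Assumes the surrounding structure of the training script; the per-step and
# end-of-epoch sample_images calls mentioned below live further down in the file.

# one-time generation before any training step, used for --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
    accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
    current_epoch.value = epoch + 1

    # no sample_images call at the top of the epoch any more: the end-of-epoch
    # call inside the loop already covers --sample_every_n_epochs, so sampling
    # here produced the duplicate images reported in #907

    for step, batch in enumerate(train_dataloader):
        ...  # training step; step-based sampling (--sample_every_n_steps) stays here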