fix duplicated sample gen for every epoch ref #907

Kohya S
2023-12-07 22:13:38 +09:00
parent db84530074
commit 912dca8f65
6 changed files with 44 additions and 53 deletions


@@ -409,9 +409,7 @@ class NetworkTrainer:
         else:
             for t_enc in text_encoders:
                 t_enc.to(accelerator.device, dtype=weight_dtype)
-        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            network, optimizer, train_dataloader, lr_scheduler
-        )
+        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
 
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
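
The first hunk is formatting only: the multi-line accelerator.prepare call is collapsed onto one line, with no behavior change. For context, a minimal sketch of the usual Accelerate pattern this call follows; the toy model, optimizer, data, and scheduler below are illustrative stand-ins, not objects from this repo:

from accelerate import Accelerator
import torch
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

# illustrative stand-ins for network / optimizer / train_dataloader / lr_scheduler
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=4)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

# prepare() wraps each object for the current device / distributed setup
# and returns them in the same order they were passed in
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)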
@@ -725,6 +723,9 @@ class NetworkTrainer:
                 accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                 os.remove(old_ckpt_file)
 
+        # For --sample_at_first
+        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
         # training loop
         for epoch in range(num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -732,8 +733,6 @@ class NetworkTrainer:
             metadata["ss_epoch"] = str(epoch + 1)
 
-            # For --sample_at_first
-            self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
             for step, batch in enumerate(train_dataloader):
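
Taken together, the second and third hunks are the actual fix for #907: the --sample_at_first generation used to sit at the top of the epoch loop, so it ran again at the start of every epoch; it now runs exactly once, before the loop, with epoch 0. A schematic sketch of the corrected control flow (bodies elided):

# For --sample_at_first: generate sample images once, before training starts
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

# training loop
for epoch in range(num_train_epochs):
    accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
    # no sample_images call here any more, so samples are not
    # regenerated at the start of every epoch
    ...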
@@ -807,7 +806,7 @@ class NetworkTrainer:
                 loss = loss.mean()  # it's the mean, so no need to divide by batch_size
 
                 accelerator.backward(loss)
-                self.all_reduce_network(accelerator, network) # sync DDP grad manually
+                self.all_reduce_network(accelerator, network)  # sync DDP grad manually
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                     params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
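
The last hunk only reflows the manual all-reduce line; the surrounding context is the standard Accelerate backward/clip step. A hedged sketch of that pattern, assuming the usual optimizer bookkeeping follows (the trailing step/zero_grad lines are typical but not part of this diff):

accelerator.backward(loss)

# sync DDP gradients manually via the trainer's all_reduce_network helper,
# used when gradients are not synchronized automatically
self.all_reduce_network(accelerator, network)

# clip only on steps where gradients are actually synchronized,
# and only if a positive max norm was requested
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
    params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

# not in this diff: the optimizer/scheduler bookkeeping that usually follows
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)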