From d05965dbadf430dab6a05f171292f6d2077ec946 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 13 Mar 2024 18:33:51 +0800 Subject: [PATCH] train_network.py: fix validation trigger condition, print only on main process, and record val loss by global step --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 864bfd70..cc9fcbbe 100644 --- a/train_network.py +++ b/train_network.py @@ -987,8 +987,8 @@ class NetworkTrainer: accelerator.log(logs, step=global_step) if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps: - print(f"\nValidating バリデーション処理...") + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) @@ -998,7 +998,7 @@ class NetworkTrainer: loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) if args.logging_dir is not None: logs = {"loss/current_val_loss": current_loss}