Update train_network.py

This commit is contained in:
gesen2egee
2024-03-13 18:33:51 +08:00
parent 5d7ed0dff0
commit d05965dbad

View File

@@ -987,8 +987,8 @@ class NetworkTrainer:
accelerator.log(logs, step=global_step)
if len(val_dataloader) > 0:
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps:
print(f"\nValidating バリデーション処理...")
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
accelerator.print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
@@ -998,7 +998,7 @@ class NetworkTrainer:
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item()
current_loss = total_loss / validation_steps
val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
if args.logging_dir is not None:
logs = {"loss/current_val_loss": current_loss}