diff --git a/train_network.py b/train_network.py index 6eefdb2b..128690fb 100644 --- a/train_network.py +++ b/train_network.py @@ -981,20 +981,19 @@ class NetworkTrainer: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if args.validation_every_n_step is not None: - if global_step % (args.validation_every_n_step) == 0: - if len(val_dataloader) > 0: + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps: print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): - validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss) if args.logging_dir is not None: logs = {"loss/current_val_loss": current_loss} @@ -1009,25 +1008,6 @@ class NetworkTrainer: if args.logging_dir is not None: logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) - - if args.validation_every_n_step is None: - if len(val_dataloader) > 0: - print(f"\nValidating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - is_train = False - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/epoch_val_average": avr_loss} - accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() @@ -1184,14 +1164,14 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_every_n_step", type=int, default=None, - help="Number of steps for counting validation loss. By default, validation per epoch is performed" + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" ) parser.add_argument( - "--validation_batches", + "--max_validation_steps", type=int, default=None, - help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed" - ) + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset" + ) return parser