diff --git a/train_network.py b/train_network.py index 5a80d825..f3c8d8c9 100644 --- a/train_network.py +++ b/train_network.py @@ -1199,7 +1199,8 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() - val_loss_recorder = train_util.LossRecorder() + val_step_loss_recorder = train_util.LossRecorder() + val_epoch_loss_recorder = train_util.LossRecorder() del train_dataset_group if val_dataset_group is not None: @@ -1299,7 +1300,8 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - loss = self.process_batch(batch, + loss = self.process_batch( + batch, text_encoders, unet, network, @@ -1373,15 +1375,25 @@ class NetworkTrainer: if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + args, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer, + keys_scaled, + mean_norm, + maximum_norm ) # accelerator.log(logs, step=global_step) accelerator.log(logs) # VALIDATION PER STEP - should_validate = (args.validation_every_n_step is not None - and global_step % args.validation_every_n_step == 0) - if validation_steps > 0 and should_validate: + should_validate_epoch = ( + args.validate_every_n_steps is not None + and global_step % args.validate_every_n_steps == 0 + ) + if validation_steps > 0 and should_validate_epoch: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( @@ -1409,16 +1421,17 @@ class NetworkTrainer: is_train=False ) - val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/current_val_loss": loss.detach().item()} + logs = {"loss/step_validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) - logs = {"loss/average_val_loss": val_loss_recorder.moving_average} + logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average} # accelerator.log(logs, step=global_step) accelerator.log(logs) @@ -1426,12 +1439,18 @@ class NetworkTrainer: break # VALIDATION EPOCH - if len(val_dataloader) > 0: + should_validate_epoch = ( + (epoch + 1) % args.validate_every_n_epochs == 0 + if args.validate_every_n_epochs is not None + else False + ) + + if should_validate_epoch and len(val_dataloader) > 0: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, - desc="validation steps" + desc="epoch validation steps" ) for val_step, batch in enumerate(val_dataloader): @@ -1455,18 +1474,18 @@ class NetworkTrainer: ) current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/validation_current": current_loss} + logs = {"loss/epoch_validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_average": avr_loss} + avr_loss: float = val_epoch_loss_recorder.moving_average + logs = {"loss/epoch_validation_average": avr_loss} # accelerator.log(logs, step=epoch + 1) accelerator.log(logs) @@ -1475,12 +1494,6 @@ class NetworkTrainer: logs = {"loss/epoch_average": loss_recorder.moving_average} # accelerator.log(logs, step=epoch + 1) accelerator.log(logs) - - if len(val_dataloader) > 0 and is_tracking: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_epoch_average": avr_loss} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) accelerator.wait_for_everyone() @@ -1676,10 +1689,16 @@ def setup_parser() -> argparse.ArgumentParser: help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" ) parser.add_argument( - "--validation_every_n_step", + "--validate_every_n_steps", type=int, default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + help="Run validation dataset every N steps" + ) + parser.add_argument( + "--validate_every_n_epochs", + type=int, + default=None, + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available" ) parser.add_argument( "--max_validation_steps",