From 4a369961346ca153a370728247449978d8a33415 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 18 Feb 2025 22:05:08 +0900 Subject: [PATCH] modify log step calculation --- train_network.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/train_network.py b/train_network.py index 47c4bb56..93558da4 100644 --- a/train_network.py +++ b/train_network.py @@ -1464,11 +1464,10 @@ class NetworkTrainer: ) if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/step_current": current_loss} + accelerator.log( + logs, step=global_step + val_ts_step + ) # a bit weird to log with global_step + val_ts_step self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 @@ -1545,25 +1544,20 @@ class NetworkTrainer: ) if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/epoch_current": current_loss} + accelerator.log(logs, step=global_step + val_ts_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1, } - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1574,8 +1568,8 @@ class NetworkTrainer: # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} - accelerator.log(logs, step=global_step) + logs = {"loss/epoch_average": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone()