modify log step calculation

This commit is contained in:
Kohya S
2025-02-18 22:05:08 +09:00
parent dc7d5fb459
commit 4a36996134

View File

@@ -1464,11 +1464,10 @@ class NetworkTrainer:
                     )
                     if is_tracking:
-                        logs = {
-                            "loss/validation/step_current": current_loss,
-                            "val_step": (epoch * validation_total_steps) + val_ts_step,
-                        }
-                        accelerator.log(logs, step=global_step)
+                        logs = {"loss/validation/step_current": current_loss}
+                        accelerator.log(
+                            logs, step=global_step + val_ts_step
+                        )  # a bit weird to log with global_step + val_ts_step
                     self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
                     val_ts_step += 1
@@ -1545,12 +1544,8 @@ class NetworkTrainer:
                     )
                     if is_tracking:
-                        logs = {
-                            "loss/validation/epoch_current": current_loss,
-                            "epoch": epoch + 1,
-                            "val_step": (epoch * validation_total_steps) + val_ts_step,
-                        }
-                        accelerator.log(logs, step=global_step)
+                        logs = {"loss/validation/epoch_current": current_loss}
+                        accelerator.log(logs, step=global_step + val_ts_step)
                     self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
                     val_ts_step += 1
@@ -1561,9 +1556,8 @@ class NetworkTrainer:
                 logs = {
                     "loss/validation/epoch_average": avr_loss,
                     "loss/validation/epoch_divergence": loss_validation_divergence,
-                    "epoch": epoch + 1,
                 }
-                accelerator.log(logs, step=global_step)
+                accelerator.log(logs, step=epoch + 1)
                 restore_rng_state(rng_states)
                 args.min_timestep = original_args_min_timestep
@@ -1574,8 +1568,8 @@ class NetworkTrainer:
             # END OF EPOCH
             if is_tracking:
-                logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
-                accelerator.log(logs, step=global_step)
+                logs = {"loss/epoch_average": loss_recorder.moving_average}
+                accelerator.log(logs, step=epoch + 1)
             accelerator.wait_for_everyone()