mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
modify log step calculation
This commit is contained in:
@@ -1464,11 +1464,10 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {
|
logs = {"loss/validation/step_current": current_loss}
|
||||||
"loss/validation/step_current": current_loss,
|
accelerator.log(
|
||||||
"val_step": (epoch * validation_total_steps) + val_ts_step,
|
logs, step=global_step + val_ts_step
|
||||||
}
|
) # a bit weird to log with global_step + val_ts_step
|
||||||
accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
val_ts_step += 1
|
val_ts_step += 1
|
||||||
@@ -1545,25 +1544,20 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {
|
logs = {"loss/validation/epoch_current": current_loss}
|
||||||
"loss/validation/epoch_current": current_loss,
|
accelerator.log(logs, step=global_step + val_ts_step)
|
||||||
"epoch": epoch + 1,
|
|
||||||
"val_step": (epoch * validation_total_steps) + val_ts_step,
|
|
||||||
}
|
|
||||||
accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
val_ts_step += 1
|
val_ts_step += 1
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||||
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
|
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
|
||||||
logs = {
|
logs = {
|
||||||
"loss/validation/epoch_average": avr_loss,
|
"loss/validation/epoch_average": avr_loss,
|
||||||
"loss/validation/epoch_divergence": loss_validation_divergence,
|
"loss/validation/epoch_divergence": loss_validation_divergence,
|
||||||
"epoch": epoch + 1,
|
|
||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
restore_rng_state(rng_states)
|
restore_rng_state(rng_states)
|
||||||
args.min_timestep = original_args_min_timestep
|
args.min_timestep = original_args_min_timestep
|
||||||
@@ -1574,8 +1568,8 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# END OF EPOCH
|
# END OF EPOCH
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
|
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user