fix wandb val logging

This commit is contained in:
Kohya S
2025-02-21 22:07:35 +09:00
parent 4a36996134
commit efb2a128cd
2 changed files with 80 additions and 50 deletions

View File

@@ -119,6 +119,45 @@ class NetworkTrainer:
return logs
def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
self.accelerator_logging(accelerator, logs, global_step, global_step, epoch)
def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
self.accelerator_logging(accelerator, logs, epoch, global_step, epoch)
def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step)
def accelerator_logging(
self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
):
"""
step_value is for tensorboard, other values are for wandb
"""
tensorboard_tracker = None
wandb_tracker = None
other_trackers = []
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
tensorboard_tracker = accelerator.get_tracker("tensorboard")
elif tracker.name == "wandb":
wandb_tracker = accelerator.get_tracker("wandb")
else:
other_trackers.append(accelerator.get_tracker(tracker.name))
if tensorboard_tracker is not None:
tensorboard_tracker.log(logs, step=step_value)
if wandb_tracker is not None:
logs["global_step"] = global_step
logs["epoch"] = epoch
if val_step is not None:
logs["val_step"] = val_step
wandb_tracker.log(logs)
for tracker in other_trackers:
tracker.log(logs, step=step_value)
def assert_extra_args(
self,
args,
@@ -1412,7 +1451,7 @@ class NetworkTrainer:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
)
accelerator.log(logs, step=global_step)
self.step_logging(accelerator, logs, global_step, epoch + 1)
# VALIDATION PER STEP: global_step is already incremented
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
@@ -1428,7 +1467,7 @@ class NetworkTrainer:
disable=not accelerator.is_local_main_process,
desc="validation steps",
)
val_ts_step = 0
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
@@ -1457,20 +1496,18 @@ class NetworkTrainer:
)
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
)
if is_tracking:
logs = {"loss/validation/step_current": current_loss}
accelerator.log(
logs, step=global_step + val_ts_step
) # a bit weird to log with global_step + val_ts_step
# if is_tracking:
# logs = {f"loss/validation/step_current_{timestep}": current_loss}
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_ts_step += 1
val_timesteps_step += 1
if is_tracking:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
@@ -1478,7 +1515,7 @@ class NetworkTrainer:
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=global_step)
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
@@ -1507,7 +1544,7 @@ class NetworkTrainer:
desc="epoch validation steps",
)
val_ts_step = 0
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
@@ -1537,18 +1574,18 @@ class NetworkTrainer:
)
current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
)
if is_tracking:
logs = {"loss/validation/epoch_current": current_loss}
accelerator.log(logs, step=global_step + val_ts_step)
# if is_tracking:
# logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_ts_step += 1
val_timesteps_step += 1
if is_tracking:
avr_loss: float = val_epoch_loss_recorder.moving_average
@@ -1557,7 +1594,7 @@ class NetworkTrainer:
"loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=epoch + 1)
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
@@ -1569,7 +1606,7 @@ class NetworkTrainer:
# END OF EPOCH
if is_tracking:
logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
accelerator.wait_for_everyone()