Fix W&B logging bugs

This commit is contained in:
Duoong
2026-02-08 06:41:19 +07:00
parent d2d111e826
commit bc6051fdc2

View File

@@ -578,6 +578,11 @@ def train(args):
init_kwargs=init_kwargs,
)
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb
wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch")
if is_swapping_blocks:
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
@@ -782,8 +787,8 @@ def train(args):
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
logs = {"loss/epoch": loss_recorder.moving_average, "epoch": epoch + 1}
accelerator.log(logs, step=global_step)
accelerator.wait_for_everyone()