expand details in tensorboard logs

- Update TensorBoard logging to track both the U-Net and text encoder learning rates
- Update TensorBoard logging to track both the current loss and the moving-average loss at each step
- Clean up TensorBoard log variable names for dashboard formatting (e.g. loss/epoch)
Author: michaelgzhang
Date:   2023-01-18 13:10:13 -06:00
Parent: 37fbefb3cd
Commit: 303c3410e2
2 changed files with 18 additions and 5 deletions


@@ -330,20 +330,21 @@ def train(args):
             global_step += 1

             current_loss = loss.detach().item()
-            if args.logging_dir is not None:
-                logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-                accelerator.log(logs, step=global_step)
-
             loss_total += current_loss
             avr_loss = loss_total / (step+1)
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

+            if args.logging_dir is not None:
+                logs = train_util.generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
+                accelerator.log(logs, step=global_step)
+
             if global_step >= args.max_train_steps:
                 break

         if args.logging_dir is not None:
-            logs = {"epoch_loss": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_total / len(train_dataloader)}
             accelerator.log(logs, step=epoch+1)

         accelerator.wait_for_everyone()
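
The second changed file (not shown above) adds the train_util.generate_step_logs helper that the new per-step logging calls. Below is only a rough sketch of what such a helper might look like, inferred from the commit description; the exact tag names, the float casts, and the assumption that the text encoder and U-Net are the first and last optimizer param groups are illustrative, not taken from the actual file.

# Hypothetical sketch of train_util.generate_step_logs; tag names and
# param-group indexing are assumptions based on the commit description.
def generate_step_logs(args, current_loss, avr_loss, lr_scheduler):
    # args is accepted only to match the call site in the diff; unused here.
    lrs = lr_scheduler.get_last_lr()  # one learning rate per optimizer param group
    logs = {
        "loss/current": current_loss,      # loss of this step
        "loss/average": avr_loss,          # running average over the epoch so far
        "lr/textencoder": float(lrs[0]),   # assumes the text encoder is the first param group
        "lr/unet": float(lrs[-1]),         # assumes the U-Net is the last param group
    }
    return logs

The slash-separated names are what "dashboard formatting" refers to: TensorBoard groups scalar tags by the prefix before the slash, so loss/current, loss/average and loss/epoch appear together in one group and the lr/* curves in another.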