Improve wandb logging (#1576)

* fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified

* fix: check whether wandb is enabled

* feat: log images to wandb with their positive prompts as captions

* feat: log sample images' captions for sd3 and flux

* fix: import wandb before use
This commit is contained in:
Plat
2024-09-11 22:21:16 +09:00
committed by GitHub
parent d83f2e92da
commit a823fd9fb8
14 changed files with 80 additions and 49 deletions

View File

@@ -538,7 +538,7 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -556,7 +556,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)