mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Improve wandb logging (#1576)
* fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified * fix: checking of whether wandb is enabled * feat: log images to wandb with their positive prompt as captions * feat: logging sample images' caption for sd3 and flux * fix: import wandb before use
This commit is contained in:
@@ -550,6 +550,9 @@ class TextualInversionTrainer:
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
if len(accelerator.trackers) > 0:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
@@ -684,7 +687,7 @@ class TextualInversionTrainer:
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
@@ -702,7 +705,7 @@ class TextualInversionTrainer:
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
@@ -739,6 +742,7 @@ class TextualInversionTrainer:
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
accelerator.log({})
|
||||
|
||||
# end of epoch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user