Improve wandb logging (#1576)

* fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified

* fix: checking of whether wandb is enabled

* feat: log images to wandb with their positive prompt as captions

* feat: logging sample images' caption for sd3 and flux

* fix: import wandb before use
This commit is contained in:
Plat
2024-09-11 22:21:16 +09:00
committed by GitHub
parent d83f2e92da
commit a823fd9fb8
14 changed files with 80 additions and 49 deletions

View File

@@ -1038,6 +1038,9 @@ class NetworkTrainer:
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
@@ -1224,7 +1227,7 @@ class NetworkTrainer:
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
)
@@ -1233,7 +1236,7 @@ class NetworkTrainer:
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)