mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Improve wandb logging (#1576)
* fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified * fix: checking of whether wandb is enabled * feat: log images to wandb with their positive prompt as captions * feat: logging sample images' caption for sd3 and flux * fix: import wandb before use
This commit is contained in:
@@ -1038,6 +1038,9 @@ class NetworkTrainer:
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
if len(accelerator.trackers) > 0:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
@@ -1224,7 +1227,7 @@ class NetworkTrainer:
|
||||
if args.scale_weight_norms:
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if args.logging_dir is not None:
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
|
||||
)
|
||||
@@ -1233,7 +1236,7 @@ class NetworkTrainer:
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user