diff --git a/library/train_util.py b/library/train_util.py
index 0fdbadc1..ffa099ef 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1423,5 +1423,17 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
     model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
     accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
 
+def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
+    logs = {"loss/current": current_loss, "loss/average": avr_loss}
+
+    if args.network_train_unet_only:
+        logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+    elif args.network_train_text_encoder_only:
+        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+    else:
+        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+        logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]
+
+    return logs
 
 # endregion
diff --git a/train_network.py b/train_network.py
index 8b4e008b..bd45d980 100644
--- a/train_network.py
+++ b/train_network.py
@@ -347,20 +347,21 @@ def train(args):
                 global_step += 1
 
             current_loss = loss.detach().item()
-            if args.logging_dir is not None:
-                logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-                accelerator.log(logs, step=global_step)
-
             loss_total += current_loss
             avr_loss = loss_total / (step+1)
             logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
 
+            if args.logging_dir is not None:
+                logs = train_util.generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
+
+                accelerator.log(logs, step=global_step)
+
             if global_step >= args.max_train_steps:
                 break
 
         if args.logging_dir is not None:
-            logs = {"epoch_loss": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_total / len(train_dataloader)}
             accelerator.log(logs, step=epoch+1)
 
         accelerator.wait_for_everyone()
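
A note on the indexing in generate_step_logs: when both networks are trained, the code assumes the scheduler's first parameter group carries the text encoder learning rate and the last carries the U-Net's, so get_last_lr()[0] and get_last_lr()[-1] recover the two rates (they coincide when only one group exists). Below is a minimal standalone sketch of that get_last_lr() behavior; the two-group AdamW setup is a hypothetical stand-in, not the trainer's exact optimizer construction.

# Sketch only: shows the get_last_lr() ordering that generate_step_logs
# relies on. The parameter groups below are hypothetical stand-ins for
# the text encoder and U-Net parameters.
import torch
from torch.optim.lr_scheduler import LambdaLR

te_params = [torch.nn.Parameter(torch.zeros(1))]    # stand-in: text encoder
unet_params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in: U-Net

optimizer = torch.optim.AdamW(
    [{"params": te_params, "lr": 5e-5}, {"params": unet_params, "lr": 1e-4}]
)
lr_scheduler = LambdaLR(optimizer, lr_lambda=[lambda s: 1.0, lambda s: 1.0])

print(lr_scheduler.get_last_lr())      # [5e-05, 0.0001] -- one entry per group
print(lr_scheduler.get_last_lr()[0])   # -> logs["lr/textencoder"]
print(lr_scheduler.get_last_lr()[-1])  # -> logs["lr/unet"]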