From 22ee0ac467ffc914fe174b7006308d6cbf7a6f63 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 21 Jan 2023 12:51:17 +0900
Subject: [PATCH] Move TE/UN loss calc to train script

---
 library/train_util.py | 12 ------------
 train_network.py      | 17 +++++++++++++++--
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index ffa099ef..0fdbadc1 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1423,17 +1423,5 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
   model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
   accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
 
-def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
-  logs = {"loss/current": current_loss, "loss/average": avr_loss}
-
-  if args.network_train_unet_only:
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
-  elif args.network_train_text_encoder_only:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
-  else:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]
-
-  return logs
 
 # endregion
diff --git a/train_network.py b/train_network.py
index bd45d980..70db4450 100644
--- a/train_network.py
+++ b/train_network.py
@@ -21,6 +21,20 @@ def collate_fn(examples):
   return examples[0]
 
 
+def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
+  logs = {"loss/current": current_loss, "loss/average": avr_loss}
+
+  if args.network_train_unet_only:
+    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+  elif args.network_train_text_encoder_only:
+    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+  else:
+    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be same to textencoder
+
+  return logs
+
+
 def train(args):
   session_id = random.randint(0, 2**32)
   training_started_at = time.time()
@@ -353,8 +367,7 @@ def train(args):
       progress_bar.set_postfix(**logs)
 
       if args.logging_dir is not None:
-        logs = train_util.generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
-
+        logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
         accelerator.log(logs, step=global_step)
 
       if global_step >= args.max_train_steps: