mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev
This commit is contained in:
@@ -1423,5 +1423,17 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|||||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
||||||
|
|
||||||
|
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||||
|
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||||
|
|
||||||
|
if args.network_train_unet_only:
|
||||||
|
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
||||||
|
elif args.network_train_text_encoder_only:
|
||||||
|
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
||||||
|
else:
|
||||||
|
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
||||||
|
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]
|
||||||
|
|
||||||
|
return logs
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|||||||
@@ -347,20 +347,21 @@ def train(args):
|
|||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if args.logging_dir is not None:
|
|
||||||
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
loss_total += current_loss
|
loss_total += current_loss
|
||||||
avr_loss = loss_total / (step+1)
|
avr_loss = loss_total / (step+1)
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = train_util.generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||||
|
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||||
accelerator.log(logs, step=epoch+1)
|
accelerator.log(logs, step=epoch+1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
Reference in New Issue
Block a user