mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix to call train/eval in schedulefree #1605
This commit is contained in:
@@ -498,6 +498,7 @@ class NetworkTrainer:
|
||||
# accelerator.print(f"trainable_params: {k} = {v}")
|
||||
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
|
||||
# prepare dataloader
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
@@ -1199,6 +1200,7 @@ class NetworkTrainer:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
optimizer_eval_fn()
|
||||
self.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
||||
)
|
||||
@@ -1217,6 +1219,7 @@ class NetworkTrainer:
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
optimizer_train_fn()
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
@@ -1243,6 +1246,7 @@ class NetworkTrainer:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
optimizer_eval_fn()
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if is_main_process and saving:
|
||||
@@ -1258,6 +1262,7 @@ class NetworkTrainer:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
optimizer_train_fn()
|
||||
|
||||
# end of epoch
|
||||
|
||||
@@ -1268,6 +1273,7 @@ class NetworkTrainer:
|
||||
network = accelerator.unwrap_model(network)
|
||||
|
||||
accelerator.end_training()
|
||||
optimizer_eval_fn()
|
||||
|
||||
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
Reference in New Issue
Block a user