call optimizer eval/train for sample_at_first, also set train after resuming closes #1667

This commit is contained in:
Kohya S
2024-10-04 20:35:16 +09:00
parent 8bea039a8d
commit ba08a89894
2 changed files with 4 additions and 0 deletions

View File

@@ -706,7 +706,9 @@ def train(args):
accelerator.unwrap_model(flux).prepare_block_swap_before_forward() accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
# For --sample_at_first # For --sample_at_first
optimizer_eval_fn()
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
optimizer_train_fn()
if len(accelerator.trackers) > 0: if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb # log empty object to commit the sample images to wandb
accelerator.log({}, step=0) accelerator.log({}, step=0)

View File

@@ -1042,7 +1042,9 @@ class NetworkTrainer:
text_encoder = None text_encoder = None
# For --sample_at_first # For --sample_at_first
optimizer_eval_fn()
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
optimizer_train_fn()
if len(accelerator.trackers) > 0: if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb # log empty object to commit the sample images to wandb
accelerator.log({}, step=0) accelerator.log({}, step=0)