sample images in training (not fully tested)

This commit is contained in:
Kohya S
2023-02-27 17:48:32 +09:00
parent a28f9ae7a3
commit dd523c94ff
5 changed files with 209 additions and 6 deletions

View File

@@ -278,6 +278,8 @@ def train(args):
progress_bar.update(1)
global_step += 1
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@@ -309,6 +311,8 @@ def train(args):
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)