sample images in training (not fully tested)

This commit is contained in:
Kohya S
2023-02-27 17:48:32 +09:00
parent a28f9ae7a3
commit dd523c94ff
5 changed files with 209 additions and 6 deletions

View File

@@ -1,4 +1,3 @@
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse
@@ -12,7 +11,6 @@ import json
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
@@ -400,6 +398,8 @@ def train(args):
progress_bar.update(1)
global_step += 1
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
@@ -445,6 +445,8 @@ def train(args):
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# end of epoch
metadata["ss_epoch"] = str(num_train_epochs)