mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
sample images in training (not fully tested)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import importlib
|
||||
import argparse
|
||||
@@ -12,7 +11,6 @@ import json
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
@@ -400,6 +398,8 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
@@ -445,6 +445,8 @@ def train(args):
|
||||
if saving and args.save_state:
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# end of epoch
|
||||
|
||||
metadata["ss_epoch"] = str(num_train_epochs)
|
||||
|
||||
Reference in New Issue
Block a user