mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add feature to sample images during sdxl training
This commit is contained in:
@@ -290,8 +290,9 @@ def train(args):
|
||||
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
text_encoder1.to("cpu")
|
||||
text_encoder2.to("cpu")
|
||||
# Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
text_encoder2.to("cpu", dtype=torch.float32)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
@@ -467,19 +468,17 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
# sdxl_train_util.sample_images(
|
||||
# accelerator,
|
||||
# args,
|
||||
# None,
|
||||
# global_step,
|
||||
# accelerator.device,
|
||||
# vae,
|
||||
# tokenizer1,
|
||||
# tokenizer2,
|
||||
# text_encoder1,
|
||||
# text_encoder2,
|
||||
# unet,
|
||||
# )
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
@@ -553,7 +552,17 @@ def train(args):
|
||||
ckpt_info,
|
||||
)
|
||||
|
||||
# train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
# if is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user