add feature to sample images during sdxl training

This commit is contained in:
Kohya S
2023-07-02 16:42:19 +09:00
parent 227a62e4c4
commit 64cf922841
5 changed files with 1402 additions and 32 deletions

View File

@@ -290,8 +290,9 @@ def train(args):
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
)
accelerator.wait_for_everyone()
text_encoder1.to("cpu")
text_encoder2.to("cpu")
# Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
@@ -467,19 +468,17 @@ def train(args):
progress_bar.update(1)
global_step += 1
# sdxl_train_util.sample_images(
# accelerator,
# args,
# None,
# global_step,
# accelerator.device,
# vae,
# tokenizer1,
# tokenizer2,
# text_encoder1,
# text_encoder2,
# unet,
# )
sdxl_train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -553,7 +552,17 @@ def train(args):
ckpt_info,
)
# train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
sdxl_train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)
is_main_process = accelerator.is_main_process
# if is_main_process: