Enable sample generation in SDXL ControlNet training

This commit is contained in:
Kohya S
2024-09-30 23:39:32 +09:00
parent d78f6a775c
commit 793999d116
5 changed files with 322 additions and 165 deletions

View File

@@ -83,6 +83,7 @@ def train(args):
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
@@ -436,19 +437,19 @@ def train(args):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# # For --sample_at_first
# sdxl_train_util.sample_images(
# accelerator,
# args,
# 0,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# unet,
# controlnet=control_net,
# )
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# training loop
for epoch in range(num_train_epochs):
@@ -484,7 +485,7 @@ def train(args):
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2]
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
@@ -558,18 +559,18 @@ def train(args):
progress_bar.update(1)
global_step += 1
# sdxl_train_util.sample_images(
# accelerator,
# args,
# None,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# unet,
# controlnet=control_net,
# )
sdxl_train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -628,7 +629,7 @@ def train(args):
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)