mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
sample generation in SDXL ControlNet training
This commit is contained in:
@@ -83,6 +83,7 @@ def train(args):
|
||||
|
||||
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images
|
||||
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||
@@ -436,19 +437,19 @@ def train(args):
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# # For --sample_at_first
|
||||
# sdxl_train_util.sample_images(
|
||||
# accelerator,
|
||||
# args,
|
||||
# 0,
|
||||
# global_step,
|
||||
# accelerator.device,
|
||||
# vae,
|
||||
# [tokenizer1, tokenizer2],
|
||||
# [text_encoder1, text_encoder2],
|
||||
# unet,
|
||||
# controlnet=control_net,
|
||||
# )
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
0,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
|
||||
unet,
|
||||
controlnet=control_net,
|
||||
)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
@@ -484,7 +485,7 @@ def train(args):
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
|
||||
- tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
|
||||
+ tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2]
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
|
||||
@@ -558,18 +559,18 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
# sdxl_train_util.sample_images(
|
||||
# accelerator,
|
||||
# args,
|
||||
# None,
|
||||
# global_step,
|
||||
# accelerator.device,
|
||||
# vae,
|
||||
# [tokenizer1, tokenizer2],
|
||||
# [text_encoder1, text_encoder2],
|
||||
# unet,
|
||||
# controlnet=control_net,
|
||||
# )
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
|
||||
unet,
|
||||
controlnet=control_net,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
@@ -628,7 +629,7 @@ def train(args):
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
- [text_encoder1, text_encoder2],
|
||||
+ [text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
|
||||
unet,
|
||||
controlnet=control_net,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user