sample images for training

This commit is contained in:
Kohya S
2024-07-29 23:18:34 +09:00
parent 1a977e847a
commit 002d75179a
2 changed files with 367 additions and 32 deletions

View File

@@ -299,6 +299,7 @@ def train(args):
t5xxl.eval()
# cache text encoder outputs
sample_prompts_te_outputs = None
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode and run without grad here
clip_l.to(accelerator.device)
@@ -321,6 +322,22 @@ def train(args):
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process)
# cache the sample prompts' embeddings so the text encoders' memory can be freed
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
prompts = sd3_train_utils.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_list = sd3_tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list
)
accelerator.wait_for_everyone()
# load MMDIT
@@ -635,10 +652,8 @@ def train(args):
init_kwargs=init_kwargs,
)
# # For --sample_at_first
# sd3_train_utils.sample_images(
# accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit
# )
# For --sample_at_first
sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs)
# following function will be moved to sd3_train_utils
@@ -831,17 +846,9 @@ def train(args):
progress_bar.update(1)
global_step += 1
# sdxl_train_util.sample_images(
# accelerator,
# args,
# None,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [clip_l, clip_g],
# mmdit,
# )
sd3_train_utils.sample_images(
accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -900,17 +907,9 @@ def train(args):
vae,
)
# sdxl_train_util.sample_images(
# accelerator,
# args,
# epoch + 1,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [clip_l, clip_g],
# mmdit,
# )
sd3_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs
)
is_main_process = accelerator.is_main_process
# if is_main_process: