mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
sample images for training
This commit is contained in:
51
sd3_train.py
51
sd3_train.py
@@ -299,6 +299,7 @@ def train(args):
|
||||
t5xxl.eval()
|
||||
|
||||
# cache text encoder outputs
|
||||
sample_prompts_te_outputs = None
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad here
|
||||
clip_l.to(accelerator.device)
|
||||
@@ -321,6 +322,22 @@ def train(args):
|
||||
|
||||
with accelerator.autocast():
|
||||
train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process)
|
||||
|
||||
# cache sample prompt's embeddings to free text encoder's memory
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
||||
prompts = sd3_train_utils.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
||||
tokens_list = sd3_tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# load MMDIT
|
||||
@@ -635,10 +652,8 @@ def train(args):
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
# # For --sample_at_first
|
||||
# sd3_train_utils.sample_images(
|
||||
# accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit
|
||||
# )
|
||||
# For --sample_at_first
|
||||
sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs)
|
||||
|
||||
# following function will be moved to sd3_train_utils
|
||||
|
||||
@@ -831,17 +846,9 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
# sdxl_train_util.sample_images(
|
||||
# accelerator,
|
||||
# args,
|
||||
# None,
|
||||
# global_step,
|
||||
# accelerator.device,
|
||||
# vae,
|
||||
# [tokenizer1, tokenizer2],
|
||||
# [clip_l, clip_g],
|
||||
# mmdit,
|
||||
# )
|
||||
sd3_train_utils.sample_images(
|
||||
accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
@@ -900,17 +907,9 @@ def train(args):
|
||||
vae,
|
||||
)
|
||||
|
||||
# sdxl_train_util.sample_images(
|
||||
# accelerator,
|
||||
# args,
|
||||
# epoch + 1,
|
||||
# global_step,
|
||||
# accelerator.device,
|
||||
# vae,
|
||||
# [tokenizer1, tokenizer2],
|
||||
# [clip_l, clip_g],
|
||||
# mmdit,
|
||||
# )
|
||||
sd3_train_utils.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs
|
||||
)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
# if is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user