fix: sample generation with system prompt, without TE output caching

Kohya S
2025-07-09 21:55:36 +09:00
parent 2fffcb605c
commit b4d1152293

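Context for the diff below: before this change, sample_image_inference built the system prompt string itself (args.system_prompt followed by the "<Prompt Start>" special token) and prepended it to the positive prompt. The change removes that and relies on the tokenize strategy to do it, tokenizing the negative prompt with is_negative=True. A minimal sketch of what that tokenize-side behavior is assumed to look like follows; the class name, constructor arguments, and the exact treatment of is_negative are illustrative assumptions, not code from this repository.

    from typing import Any

    class LuminaTokenizeStrategySketch:  # hypothetical name, for illustration only
        SYSTEM_PROMPT_SPECIAL_TOKEN = "<Prompt Start>"

        def __init__(self, tokenizer: Any, system_prompt: str = "", max_length: int = 256):
            self.tokenizer = tokenizer            # a Hugging Face tokenizer (e.g. Gemma2's)
            self.system_prompt = system_prompt    # corresponds to args.system_prompt
            self.max_length = max_length

        def tokenize(self, text: str, is_negative: bool = False):
            # Assumption: positive prompts become "<system prompt> <Prompt Start> <prompt>",
            # the same string the sampling code used to build by hand. What is_negative
            # changes is not visible in this diff; here it simply skips the system prompt.
            if self.system_prompt and not is_negative:
                text = f"{self.system_prompt} {self.SYSTEM_PROMPT_SPECIAL_TOKEN} {text}"
            return self.tokenizer(
                text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )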

@@ -249,7 +249,7 @@ def sample_image_inference(
     accelerator: Accelerator,
     args: argparse.Namespace,
     nextdit: lumina_models.NextDiT,
-    gemma2_model: Gemma2Model,
+    gemma2_model: list[Gemma2Model],
     vae: AutoEncoder,
     save_dir: str,
     prompt_dicts: list[Dict[str, str]],
@@ -266,7 +266,7 @@ def sample_image_inference(
         accelerator (Accelerator): Accelerator object
         args (argparse.Namespace): Arguments object
         nextdit (lumina_models.NextDiT): NextDiT model
-        gemma2_model (Gemma2Model): Gemma2 model
+        gemma2_model (list[Gemma2Model]): Gemma2 model
         vae (AutoEncoder): VAE model
         save_dir (str): Directory to save images
         prompt_dict (Dict[str, str]): Prompt dictionary
@@ -330,12 +330,8 @@ def sample_image_inference(
     logger.info(f"renorm: {renorm_cfg}")
     # logger.info(f"sample_sampler: {sampler_name}")
-    system_prompt_special_token = "<Prompt Start>"
-    system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
-    # Apply system prompt to prompts
-    prompt = system_prompt + prompt
-    negative_prompt = negative_prompt
+    # No need to add system prompt here, as it has been handled in the tokenize_strategy
     # Get sample prompts from cache
     if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
@@ -355,12 +351,12 @@ def sample_image_inference(
     if gemma2_model is not None:
         tokens_and_masks = tokenize_strategy.tokenize(prompt)
         gemma2_conds = encoding_strategy.encode_tokens(
-            tokenize_strategy, [gemma2_model], tokens_and_masks
+            tokenize_strategy, gemma2_model, tokens_and_masks
         )
-        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
+        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
         neg_gemma2_conds = encoding_strategy.encode_tokens(
-            tokenize_strategy, [gemma2_model], tokens_and_masks
+            tokenize_strategy, gemma2_model, tokens_and_masks
         )
         # Unpack Gemma2 outputs
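
The second half of the change passes gemma2_model straight through to encode_tokens instead of wrapping it in another list, which matches the new list[Gemma2Model] annotation: the function now receives the text encoder(s) already as a list. Below is a rough sketch of an encoding strategy consistent with that calling convention; the class name, the use of last_hidden_state, and the return shape are assumptions for illustration, not taken from this commit.

    from typing import Any, List

    class LuminaTextEncodingStrategySketch:  # hypothetical name, for illustration only
        def encode_tokens(self, tokenize_strategy: Any, models: List[Any], tokens_and_masks: dict):
            # Assumption: models is a flat list of text encoders, and Lumina uses a single
            # Gemma2 model, so models[0] must be the model itself. If the caller already
            # passes a list (as the new annotation indicates), re-wrapping it as
            # [gemma2_model], the way the old code did, would hand this method a nested list.
            gemma2 = models[0]
            outputs = gemma2(
                input_ids=tokens_and_masks["input_ids"],
                attention_mask=tokens_and_masks["attention_mask"],
            )
            # Which hidden state Lumina actually consumes is an assumption here.
            return outputs.last_hidden_state, tokens_and_masks["attention_mask"]

Under this reading, the fix keeps sample generation working when Gemma2 outputs are not cached and the text encoder has to be run on the fly, which matches the "without TE output caching" part of the commit title.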