mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
fix: sample generation with system prompt, without TE output caching
This commit is contained in:
@@ -249,7 +249,7 @@ def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
nextdit: lumina_models.NextDiT,
|
||||
gemma2_model: Gemma2Model,
|
||||
gemma2_model: list[Gemma2Model],
|
||||
vae: AutoEncoder,
|
||||
save_dir: str,
|
||||
prompt_dicts: list[Dict[str, str]],
|
||||
@@ -266,7 +266,7 @@ def sample_image_inference(
|
||||
accelerator (Accelerator): Accelerator object
|
||||
args (argparse.Namespace): Arguments object
|
||||
nextdit (lumina_models.NextDiT): NextDiT model
|
||||
gemma2_model (Gemma2Model): Gemma2 model
|
||||
gemma2_model (list[Gemma2Model]): Gemma2 model
|
||||
vae (AutoEncoder): VAE model
|
||||
save_dir (str): Directory to save images
|
||||
prompt_dict (Dict[str, str]): Prompt dictionary
|
||||
@@ -330,12 +330,8 @@ def sample_image_inference(
|
||||
logger.info(f"renorm: {renorm_cfg}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
||||
|
||||
# Apply system prompt to prompts
|
||||
prompt = system_prompt + prompt
|
||||
negative_prompt = negative_prompt
|
||||
# No need to add system prompt here, as it has been handled in the tokenize_strategy
|
||||
|
||||
# Get sample prompts from cache
|
||||
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
||||
@@ -355,12 +351,12 @@ def sample_image_inference(
|
||||
if gemma2_model is not None:
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
gemma2_conds = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [gemma2_model], tokens_and_masks
|
||||
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||
)
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
|
||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
|
||||
neg_gemma2_conds = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [gemma2_model], tokens_and_masks
|
||||
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||
)
|
||||
|
||||
# Unpack Gemma2 outputs
|
||||
|
||||
Reference in New Issue
Block a user