mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
fix: sample generation with system prompt, without TE output caching
This commit is contained in:
@@ -249,7 +249,7 @@ def sample_image_inference(
|
|||||||
accelerator: Accelerator,
|
accelerator: Accelerator,
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
nextdit: lumina_models.NextDiT,
|
nextdit: lumina_models.NextDiT,
|
||||||
gemma2_model: Gemma2Model,
|
gemma2_model: list[Gemma2Model],
|
||||||
vae: AutoEncoder,
|
vae: AutoEncoder,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
prompt_dicts: list[Dict[str, str]],
|
prompt_dicts: list[Dict[str, str]],
|
||||||
@@ -266,7 +266,7 @@ def sample_image_inference(
|
|||||||
accelerator (Accelerator): Accelerator object
|
accelerator (Accelerator): Accelerator object
|
||||||
args (argparse.Namespace): Arguments object
|
args (argparse.Namespace): Arguments object
|
||||||
nextdit (lumina_models.NextDiT): NextDiT model
|
nextdit (lumina_models.NextDiT): NextDiT model
|
||||||
gemma2_model (Gemma2Model): Gemma2 model
|
gemma2_model (list[Gemma2Model]): Gemma2 model
|
||||||
vae (AutoEncoder): VAE model
|
vae (AutoEncoder): VAE model
|
||||||
save_dir (str): Directory to save images
|
save_dir (str): Directory to save images
|
||||||
prompt_dict (Dict[str, str]): Prompt dictionary
|
prompt_dict (Dict[str, str]): Prompt dictionary
|
||||||
@@ -330,12 +330,8 @@ def sample_image_inference(
|
|||||||
logger.info(f"renorm: {renorm_cfg}")
|
logger.info(f"renorm: {renorm_cfg}")
|
||||||
# logger.info(f"sample_sampler: {sampler_name}")
|
# logger.info(f"sample_sampler: {sampler_name}")
|
||||||
|
|
||||||
system_prompt_special_token = "<Prompt Start>"
|
|
||||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
|
||||||
|
|
||||||
# Apply system prompt to prompts
|
# No need to add system prompt here, as it has been handled in the tokenize_strategy
|
||||||
prompt = system_prompt + prompt
|
|
||||||
negative_prompt = negative_prompt
|
|
||||||
|
|
||||||
# Get sample prompts from cache
|
# Get sample prompts from cache
|
||||||
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
||||||
@@ -355,12 +351,12 @@ def sample_image_inference(
|
|||||||
if gemma2_model is not None:
|
if gemma2_model is not None:
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||||
gemma2_conds = encoding_strategy.encode_tokens(
|
gemma2_conds = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [gemma2_model], tokens_and_masks
|
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||||
)
|
)
|
||||||
|
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
|
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
|
||||||
neg_gemma2_conds = encoding_strategy.encode_tokens(
|
neg_gemma2_conds = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [gemma2_model], tokens_and_masks
|
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||||
)
|
)
|
||||||
|
|
||||||
# Unpack Gemma2 outputs
|
# Unpack Gemma2 outputs
|
||||||
|
|||||||
Reference in New Issue
Block a user