diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py
index 14a79bb2..45f22bc4 100644
--- a/library/lumina_train_util.py
+++ b/library/lumina_train_util.py
@@ -249,7 +249,7 @@ def sample_image_inference(
     accelerator: Accelerator,
     args: argparse.Namespace,
     nextdit: lumina_models.NextDiT,
-    gemma2_model: Gemma2Model,
+    gemma2_model: list[Gemma2Model],
     vae: AutoEncoder,
     save_dir: str,
     prompt_dicts: list[Dict[str, str]],
@@ -266,7 +266,7 @@ def sample_image_inference(
         accelerator (Accelerator): Accelerator object
         args (argparse.Namespace): Arguments object
         nextdit (lumina_models.NextDiT): NextDiT model
-        gemma2_model (Gemma2Model): Gemma2 model
+        gemma2_model (list[Gemma2Model]): List of Gemma2 models
         vae (AutoEncoder): VAE model
         save_dir (str): Directory to save images
         prompt_dict (Dict[str, str]): Prompt dictionary
@@ -330,12 +330,8 @@ def sample_image_inference(
 
     logger.info(f"renorm: {renorm_cfg}")
     # logger.info(f"sample_sampler: {sampler_name}")
 
-    system_prompt_special_token = "<Prompt Start>"
-    system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
-    # Apply system prompt to prompts
-    prompt = system_prompt + prompt
-    negative_prompt = negative_prompt
+    # No need to add the system prompt here; it is handled in the tokenize_strategy
 
     # Get sample prompts from cache
     if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
@@ -355,12 +351,12 @@
 
     if gemma2_model is not None:
         tokens_and_masks = tokenize_strategy.tokenize(prompt)
         gemma2_conds = encoding_strategy.encode_tokens(
-            tokenize_strategy, [gemma2_model], tokens_and_masks
+            tokenize_strategy, gemma2_model, tokens_and_masks
        )
-        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
+        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
         neg_gemma2_conds = encoding_strategy.encode_tokens(
-            tokenize_strategy, [gemma2_model], tokens_and_masks
+            tokenize_strategy, gemma2_model, tokens_and_masks
         )
 
     # Unpack Gemma2 outputs
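For reference, here is a minimal sketch of the calling convention this patch moves to. It is an illustration only: the helper name `encode_sample_prompts` is hypothetical, and the strategy parameters are typed loosely because the concrete Lumina tokenize/encoding strategy classes are not shown in this diff.

```python
from typing import Any, List

from transformers import Gemma2Model


def encode_sample_prompts(
    tokenize_strategy: Any,  # stand-in for the Lumina tokenize strategy used above
    encoding_strategy: Any,  # stand-in for the Lumina text encoding strategy used above
    gemma2_model: List[Gemma2Model],  # now a list of text encoders, matching the new signature
    prompt: str,
    negative_prompt: str,
):
    # The system prompt is prepended inside the tokenize strategy now,
    # so the raw positive prompt is passed through unchanged.
    tokens_and_masks = tokenize_strategy.tokenize(prompt)
    gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, gemma2_model, tokens_and_masks)

    # is_negative=True presumably lets the strategy skip the system prompt
    # for the negative prompt (the removed code never applied it there either).
    tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
    neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, gemma2_model, tokens_and_masks)

    return gemma2_conds, neg_gemma2_conds
```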