Merge pull request #20 from rockerBOO/lumina-system-prompt-special-token

Lumina system prompt special token
This commit is contained in:
青龍聖者@bdsqlsz
2025-03-02 18:30:49 +08:00
committed by GitHub
4 changed files with 15 additions and 8 deletions

View File

@@ -330,11 +330,12 @@ def sample_image_inference(
 logger.info(f"renorm: {renorm_cfg}")
 # logger.info(f"sample_sampler: {sampler_name}")
-system_prompt = args.system_prompt or ""
+system_prompt_special_token = "<Prompt Start>"
+system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
 # Apply system prompt to prompts
 prompt = system_prompt + prompt
-negative_prompt = system_prompt + negative_prompt
+negative_prompt = negative_prompt
 # Get sample prompts from cache
 if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:

View File

@@ -217,7 +217,8 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
 assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
 assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
-captions = [info.system_prompt or "" + info.caption for info in batch]
+system_prompt_special_token = "<Prompt Start>"
+captions = [f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "" + info.caption for info in batch]
 if self.is_weighted:
 tokens, attention_masks, weights_list = (

View File

@@ -1692,7 +1692,8 @@ class BaseDataset(torch.utils.data.Dataset):
 text_encoder_outputs_list.append(text_encoder_outputs)
 if tokenization_required:
-system_prompt = subset.system_prompt or ""
+system_prompt_special_token = "<Prompt Start>"
+system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else ""
 caption = self.process_caption(subset, image_info.caption)
 input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension
 # if self.XTI_layers:
@@ -2091,7 +2092,8 @@ class DreamBoothDataset(BaseDataset):
 else:
 num_train_images += num_repeats * len(img_paths)
-system_prompt = self.system_prompt or subset.system_prompt or ""
+system_prompt_special_token = "<Prompt Start>"
+system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else ""
 for img_path, caption, size in zip(img_paths, captions, sizes):
 info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path)
 if size is not None:

View File

@@ -155,7 +155,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
 assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
 assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
-system_prompt = args.system_prompt or ""
+system_prompt_special_token = "<Prompt Start>"
+system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
 sample_prompts = train_util.load_prompts(args.sample_prompts)
 sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
 with accelerator.autocast(), torch.no_grad():
@@ -164,8 +165,10 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
 prompt_dict.get("prompt", ""),
 prompt_dict.get("negative_prompt", ""),
 ]
-for prompt in prompts:
-prompt = system_prompt + prompt
+for i, prompt in enumerate(prompts):
+# Add system prompt only to positive prompt
+if i == 0:
+prompt = system_prompt + prompt
 if prompt in sample_prompts_te_outputs:
 continue