mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge pull request #20 from rockerBOO/lumina-system-prompt-special-token
Lumina system prompt special token
This commit is contained in:
@@ -330,11 +330,12 @@ def sample_image_inference(
|
||||
logger.info(f"renorm: {renorm_cfg}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
|
||||
system_prompt = args.system_prompt or ""
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
||||
|
||||
# Apply system prompt to prompts
|
||||
prompt = system_prompt + prompt
|
||||
negative_prompt = system_prompt + negative_prompt
|
||||
negative_prompt = negative_prompt
|
||||
|
||||
# Get sample prompts from cache
|
||||
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
||||
|
||||
@@ -217,7 +217,8 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
|
||||
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
|
||||
|
||||
captions = [info.system_prompt or "" + info.caption for info in batch]
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
captions = [f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "" + info.caption for info in batch]
|
||||
|
||||
if self.is_weighted:
|
||||
tokens, attention_masks, weights_list = (
|
||||
|
||||
@@ -1692,7 +1692,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||
|
||||
if tokenization_required:
|
||||
system_prompt = subset.system_prompt or ""
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else ""
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension
|
||||
# if self.XTI_layers:
|
||||
@@ -2091,7 +2092,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
else:
|
||||
num_train_images += num_repeats * len(img_paths)
|
||||
|
||||
system_prompt = self.system_prompt or subset.system_prompt or ""
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else ""
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
|
||||
@@ -155,7 +155,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
|
||||
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
|
||||
|
||||
system_prompt = args.system_prompt or ""
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
||||
sample_prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
@@ -164,8 +165,10 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]
|
||||
for prompt in prompts:
|
||||
prompt = system_prompt + prompt
|
||||
for i, prompt in enumerate(prompts):
|
||||
# Add system prompt only to positive prompt
|
||||
if i == 0:
|
||||
prompt = system_prompt + prompt
|
||||
if prompt in sample_prompts_te_outputs:
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user