diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py
index 22c9a0b3..012922ec 100644
--- a/library/lumina_train_util.py
+++ b/library/lumina_train_util.py
@@ -330,11 +330,12 @@ def sample_image_inference(
     logger.info(f"renorm: {renorm_cfg}")
     # logger.info(f"sample_sampler: {sampler_name}")
 
-    system_prompt = args.system_prompt or ""
+    system_prompt_special_token = ""
+    system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
 
     # Apply system prompt to prompts
     prompt = system_prompt + prompt
-    negative_prompt = system_prompt + negative_prompt
+    negative_prompt = negative_prompt
 
     # Get sample prompts from cache
     if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py
index 1d149ceb..ee4180d8 100644
--- a/library/strategy_lumina.py
+++ b/library/strategy_lumina.py
@@ -217,7 +217,8 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
         assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
         assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
 
-        captions = [info.system_prompt or "" + info.caption for info in batch]
+        system_prompt_special_token = ""
+        captions = [(f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "") + info.caption for info in batch]
 
         if self.is_weighted:
             tokens, attention_masks, weights_list = (
diff --git a/library/train_util.py b/library/train_util.py
index 0c057bd1..34b98f89 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1692,7 +1692,8 @@ class BaseDataset(torch.utils.data.Dataset):
                 text_encoder_outputs_list.append(text_encoder_outputs)
 
             if tokenization_required:
-                system_prompt = subset.system_prompt or ""
+                system_prompt_special_token = ""
+                system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else ""
                 caption = self.process_caption(subset, image_info.caption)
                 input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)]  # remove batch dimension
                 # if self.XTI_layers:
@@ -2091,7 +2092,8 @@ class DreamBoothDataset(BaseDataset):
             else:
                 num_train_images += num_repeats * len(img_paths)
 
-            system_prompt = self.system_prompt or subset.system_prompt or ""
+            system_prompt_special_token = ""
+            system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else ""
             for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path)
                 if size is not None:
diff --git a/lumina_train_network.py b/lumina_train_network.py
index 60c39c20..ab811bd5 100644
--- a/lumina_train_network.py
+++ b/lumina_train_network.py
@@ -155,7 +155,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
         assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
         assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
 
-        system_prompt = args.system_prompt or ""
+        system_prompt_special_token = ""
+        system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
         sample_prompts = train_util.load_prompts(args.sample_prompts)
         sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
         with accelerator.autocast(), torch.no_grad():
@@ -164,8 +165,10 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
                     prompt_dict.get("prompt", ""),
                     prompt_dict.get("negative_prompt", ""),
                 ]
-                for prompt in prompts:
-                    prompt = system_prompt + prompt
+                for i, prompt in enumerate(prompts):
+                    # Add system prompt only to positive prompt
+                    if i == 0:
+                        prompt = system_prompt + prompt
                     if prompt in sample_prompts_te_outputs:
                         continue