feat: update lumina system prompt handling

This commit is contained in:
Kohya S
2025-06-29 21:33:09 +09:00
parent 52d13373c0
commit 935e0037dc
6 changed files with 11 additions and 17 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ venv
build
.vscode
wandb
MagicMock

View File

@@ -75,7 +75,6 @@ class BaseSubsetParams:
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
system_prompt: Optional[str] = None
resize_interpolation: Optional[str] = None
@@ -108,7 +107,6 @@ class BaseDatasetParams:
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
system_prompt: Optional[str] = None
resize_interpolation: Optional[str] = None
@dataclass
@@ -199,7 +197,6 @@ class ConfigSanitizer:
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
"system_prompt": str,
"resize_interpolation": str,
}
# DO means DropOut
@@ -246,7 +243,6 @@ class ConfigSanitizer:
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"system_prompt": str,
"resize_interpolation": str,
}
@@ -534,7 +530,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
resolution: {(dataset.width, dataset.height)}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
system_prompt: {dataset.system_prompt}
""")
if dataset.enable_bucket:
@@ -569,7 +564,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
alpha_mask: {subset.alpha_mask}
resize_interpolation: {subset.resize_interpolation}
custom_attributes: {subset.custom_attributes}
system_prompt: {subset.system_prompt}
"""), " ")
if is_dreambooth:

View File

@@ -218,8 +218,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
system_prompt_special_token = "<Prompt Start>"
captions = [f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "" + info.caption for info in batch]
captions = [info.caption for info in batch]
if self.is_weighted:
tokens, attention_masks, weights_list = (

View File

@@ -266,12 +266,14 @@ def train(args):
strategy_base.TextEncodingStrategy.get_strategy()
)
system_prompt_special_token = "<Prompt Start>"
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [
prompt_dict.get("prompt", ""),
system_prompt + prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]:
if p not in sample_prompts_te_outputs:

View File

@@ -58,7 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
torch.device("cpu"),
disable_mmap=args.disable_mmap_load_safetensors,
use_flash_attn=args.use_flash_attn,
use_sage_attn=args.use_sage_attn
use_sage_attn=args.use_sage_attn,
)
if args.fp8_base:
@@ -75,7 +75,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
model.to(torch.float8_e4m3fn)
if args.blocks_to_swap:
logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}')
logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
self.is_swapping_blocks = True
@@ -157,13 +157,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
system_prompt_special_token = "<Prompt Start>"
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
sample_prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in sample_prompts:
prompts = [
prompt_dict.get("prompt", ""),
system_prompt + prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]
for i, prompt in enumerate(prompts):
@@ -371,7 +371,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)

View File

@@ -126,13 +126,12 @@ def test_lumina_text_encoder_outputs_caching_strategy():
# Create a mock class for ImageInfo
class MockImageInfo:
def __init__(self, caption, system_prompt, cache_path):
def __init__(self, caption, cache_path):
self.caption = caption
self.system_prompt = system_prompt
self.text_encoder_outputs_npz = cache_path
# Create a sample input info
image_info = MockImageInfo("Test caption", "", cache_file)
image_info = MockImageInfo("Test caption", cache_file)
# Simulate a batch
batch = [image_info]