mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 16:22:28 +00:00
feat: update lumina system prompt handling
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@ venv
|
||||
build
|
||||
.vscode
|
||||
wandb
|
||||
MagicMock
|
||||
@@ -75,7 +75,6 @@ class BaseSubsetParams:
|
||||
custom_attributes: Optional[Dict[str, Any]] = None
|
||||
validation_seed: int = 0
|
||||
validation_split: float = 0.0
|
||||
system_prompt: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None
|
||||
|
||||
|
||||
@@ -108,7 +107,6 @@ class BaseDatasetParams:
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
system_prompt: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
@@ -199,7 +197,6 @@ class ConfigSanitizer:
|
||||
"caption_prefix": str,
|
||||
"caption_suffix": str,
|
||||
"custom_attributes": dict,
|
||||
"system_prompt": str,
|
||||
"resize_interpolation": str,
|
||||
}
|
||||
# DO means DropOut
|
||||
@@ -246,7 +243,6 @@ class ConfigSanitizer:
|
||||
"validation_split": float,
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
"system_prompt": str,
|
||||
"resize_interpolation": str,
|
||||
}
|
||||
|
||||
@@ -534,7 +530,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
resize_interpolation: {dataset.resize_interpolation}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
system_prompt: {dataset.system_prompt}
|
||||
""")
|
||||
|
||||
if dataset.enable_bucket:
|
||||
@@ -569,7 +564,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
resize_interpolation: {subset.resize_interpolation}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
system_prompt: {subset.system_prompt}
|
||||
"""), " ")
|
||||
|
||||
if is_dreambooth:
|
||||
|
||||
@@ -218,8 +218,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
|
||||
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
|
||||
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
captions = [f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "" + info.caption for info in batch]
|
||||
captions = [info.caption for info in batch]
|
||||
|
||||
if self.is_weighted:
|
||||
tokens, attention_masks, weights_list = (
|
||||
|
||||
@@ -266,12 +266,14 @@ def train(args):
|
||||
strategy_base.TextEncodingStrategy.get_strategy()
|
||||
)
|
||||
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [
|
||||
prompt_dict.get("prompt", ""),
|
||||
system_prompt + prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]:
|
||||
if p not in sample_prompts_te_outputs:
|
||||
|
||||
@@ -58,7 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
torch.device("cpu"),
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
use_flash_attn=args.use_flash_attn,
|
||||
use_sage_attn=args.use_sage_attn
|
||||
use_sage_attn=args.use_sage_attn,
|
||||
)
|
||||
|
||||
if args.fp8_base:
|
||||
@@ -75,7 +75,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
model.to(torch.float8_e4m3fn)
|
||||
|
||||
if args.blocks_to_swap:
|
||||
logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}')
|
||||
logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}")
|
||||
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
self.is_swapping_blocks = True
|
||||
|
||||
@@ -157,13 +157,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
|
||||
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
||||
sample_prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in sample_prompts:
|
||||
prompts = [
|
||||
prompt_dict.get("prompt", ""),
|
||||
system_prompt + prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]
|
||||
for i, prompt in enumerate(prompts):
|
||||
@@ -371,7 +371,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
|
||||
@@ -126,13 +126,12 @@ def test_lumina_text_encoder_outputs_caching_strategy():
|
||||
|
||||
# Create a mock class for ImageInfo
|
||||
class MockImageInfo:
|
||||
def __init__(self, caption, system_prompt, cache_path):
|
||||
def __init__(self, caption, cache_path):
|
||||
self.caption = caption
|
||||
self.system_prompt = system_prompt
|
||||
self.text_encoder_outputs_npz = cache_path
|
||||
|
||||
# Create a sample input info
|
||||
image_info = MockImageInfo("Test caption", "", cache_file)
|
||||
image_info = MockImageInfo("Test caption", cache_file)
|
||||
|
||||
# Simulate a batch
|
||||
batch = [image_info]
|
||||
|
||||
Reference in New Issue
Block a user