diff --git a/.gitignore b/.gitignore index e492b1ad..4fcf07f6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ venv build .vscode wandb +MagicMock \ No newline at end of file diff --git a/library/config_util.py b/library/config_util.py index ac726e4f..53727f25 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -75,7 +75,6 @@ class BaseSubsetParams: custom_attributes: Optional[Dict[str, Any]] = None validation_seed: int = 0 validation_split: float = 0.0 - system_prompt: Optional[str] = None resize_interpolation: Optional[str] = None @@ -108,7 +107,6 @@ class BaseDatasetParams: debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 - system_prompt: Optional[str] = None resize_interpolation: Optional[str] = None @dataclass @@ -199,7 +197,6 @@ class ConfigSanitizer: "caption_prefix": str, "caption_suffix": str, "custom_attributes": dict, - "system_prompt": str, "resize_interpolation": str, } # DO means DropOut @@ -246,7 +243,6 @@ class ConfigSanitizer: "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, - "system_prompt": str, "resize_interpolation": str, } @@ -534,7 +530,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu resolution: {(dataset.width, dataset.height)} resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} - system_prompt: {dataset.system_prompt} """) if dataset.enable_bucket: @@ -569,7 +564,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu alpha_mask: {subset.alpha_mask} resize_interpolation: {subset.resize_interpolation} custom_attributes: {subset.custom_attributes} - system_prompt: {subset.system_prompt} """), " ") if is_dreambooth: diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index d9e93f53..3d86dbef 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -218,8 +218,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) - system_prompt_special_token = "" - captions = [f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "" + info.caption for info in batch] + captions = [info.caption for info in batch] if self.is_weighted: tokens, attention_masks, weights_list = ( diff --git a/lumina_train.py b/lumina_train.py index 330d0093..4b733c9e 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -266,12 +266,14 @@ def train(args): strategy_base.TextEncodingStrategy.get_strategy() ) + system_prompt_special_token = "" + system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: for p in [ - prompt_dict.get("prompt", ""), + system_prompt + prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), ]: if p not in sample_prompts_te_outputs: diff --git a/lumina_train_network.py b/lumina_train_network.py index e1b45ac7..037ddac6 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -58,7 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): torch.device("cpu"), disable_mmap=args.disable_mmap_load_safetensors, use_flash_attn=args.use_flash_attn, - use_sage_attn=args.use_sage_attn + use_sage_attn=args.use_sage_attn, ) if args.fp8_base: @@ -75,7 +75,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): model.to(torch.float8_e4m3fn) if args.blocks_to_swap: - logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}') + logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}") model.enable_block_swap(args.blocks_to_swap, accelerator.device) self.is_swapping_blocks = True @@ -157,13 +157,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) system_prompt_special_token = "" - system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" + system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" sample_prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in sample_prompts: prompts = [ - prompt_dict.get("prompt", ""), + system_prompt + prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), ] for i, prompt in enumerate(prompts): @@ -371,7 +371,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): accelerator.unwrap_model(unet).prepare_block_swap_before_forward() - def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() train_util.add_dit_training_arguments(parser) diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py index 18e196bf..aca16347 100644 --- a/tests/library/test_strategy_lumina.py +++ b/tests/library/test_strategy_lumina.py @@ -126,13 +126,12 @@ def test_lumina_text_encoder_outputs_caching_strategy(): # Create a mock class for ImageInfo class MockImageInfo: - def __init__(self, caption, system_prompt, cache_path): + def __init__(self, caption, cache_path): self.caption = caption - self.system_prompt = system_prompt self.text_encoder_outputs_npz = cache_path # Create a sample input info - image_info = MockImageInfo("Test caption", "", cache_file) + image_info = MockImageInfo("Test caption", cache_file) # Simulate a batch batch = [image_info]