diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index d3edd262..7ade6c1b 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -227,7 +227,7 @@ def sample_image_inference( ) timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) - gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) if args.apply_gemma2_attn_mask else None + gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) # if controlnet_image is not None: # controlnet_image = Image.open(controlnet_image).convert("RGB") @@ -511,11 +511,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev" " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512", ) - parser.add_argument( - "--apply_gemma2_attn_mask", - action="store_true", - help="apply attention mask to Gemma2 encode and NextDIT double blocks / Gemma2エンコードとNextDITダブルブロックにアテンションマスクを適用する", - ) parser.add_argument( "--guidance_scale", diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 6feea387..209f62a0 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -47,7 +47,7 @@ class LuminaTokenizeStrategy(TokenizeStrategy): pad_to_multiple_of=8, truncation=True, ) - return encodings.input_ids, encodings.attention_mask + return [encodings.input_ids, encodings.attention_mask] def tokenize_with_weights( self, text: str | List[str] @@ -59,47 +59,36 @@ class LuminaTokenizeStrategy(TokenizeStrategy): class LuminaTextEncodingStrategy(TextEncodingStrategy): - def __init__(self, apply_gemma2_attn_mask: Optional[bool] = None) -> None: + def __init__(self) -> None: super().__init__() - self.apply_gemma2_attn_mask = apply_gemma2_attn_mask def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: torch.Tensor, - 
attention_masks: torch.Tensor, - apply_gemma2_attn_mask: Optional[bool] = None, - ) -> torch.Tensor: - if apply_gemma2_attn_mask is None: - apply_gemma2_attn_mask = self.apply_gemma2_attn_mask - + tokens: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: text_encoder = models[0] - - # Create position IDs - position_ids = attention_masks.cumsum(-1) - 1 - position_ids.masked_fill_(attention_masks == 0, 1) + input_ids, attention_masks = tokens outputs = text_encoder( - input_ids=tokens.to(text_encoder.device), - attention_mask=attention_masks.to(text_encoder.device) if apply_gemma2_attn_mask else None, - position_ids=position_ids.to(text_encoder.device), + input_ids=input_ids.to(text_encoder.device), + attention_mask=attention_masks.to(text_encoder.device), output_hidden_states=True, return_dict=True, ) - return outputs.hidden_states[-2] + return outputs.hidden_states[-2], input_ids, attention_masks def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: torch.Tensor, + tokens: List[torch.Tensor], weights_list: List[torch.Tensor], - attention_masks: torch.Tensor - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # For simplicity, use uniform weighting - return self.encode_tokens(tokenize_strategy, models, tokens, attention_masks) + return self.encode_tokens(tokenize_strategy, models, tokens) class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -111,7 +100,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False, - apply_gemma2_attn_mask: bool = False, ) -> None: super().__init__( cache_to_disk, @@ -119,7 +107,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) skip_disk_cache_validity_check, is_partial, ) - self.apply_gemma2_attn_mask = apply_gemma2_attn_mask def get_outputs_npz_path(self, 
image_abs_path: str) -> str: return ( @@ -146,7 +133,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) if "apply_gemma2_attn_mask" not in npz: return False npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"] - if npz_apply_gemma2_attn_mask != self.apply_gemma2_attn_mask: + if not npz_apply_gemma2_attn_mask: return False except Exception as e: logger.error(f"Error loading file: {npz_path}") @@ -174,18 +161,18 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) captions = [info.caption for info in infos] if self.is_weighted: - tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights( + tokens, weights_list = tokenize_strategy.tokenize_with_weights( captions ) with torch.no_grad(): - hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights( - tokenize_strategy, models, tokens, weights_list, attention_masks + hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens, weights_list ) else: - tokens, attention_masks = tokenize_strategy.tokenize(captions) + tokens = tokenize_strategy.tokenize(captions) with torch.no_grad(): - hidden_state = lumina_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens, attention_masks + hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens ) if hidden_state.dtype != torch.float32: @@ -200,7 +187,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) hidden_state_i = hidden_state[i] attention_mask_i = attention_mask[i] input_ids_i = input_ids[i] - apply_gemma2_attn_mask_i = self.apply_gemma2_attn_mask if self.cache_to_disk: np.savez( @@ -208,7 +194,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) hidden_state=hidden_state_i, attention_mask=attention_mask_i, input_ids=input_ids_i, - 
apply_gemma2_attn_mask=apply_gemma2_attn_mask_i, + apply_gemma2_attn_mask=True ) else: info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i] diff --git a/lumina_train_network.py b/lumina_train_network.py index 3d0c7062..00c81bce 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -64,7 +64,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): if args.fp8_base: # check dtype of model - if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + if ( + model.dtype == torch.float8_e4m3fnuz + or model.dtype == torch.float8_e5m2 + or model.dtype == torch.float8_e5m2fnuz + ): raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") elif model.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 Lumina 2 model") @@ -80,13 +84,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # model.enable_block_swap(args.blocks_to_swap, accelerator.device) # self.is_swapping_blocks = True - gemma2 = lumina_util.load_gemma2( - args.gemma2, weight_dtype, "cpu" - ) + gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") gemma2.eval() - ae = lumina_util.load_ae( - args.ae, weight_dtype, "cpu" - ) + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model @@ -104,7 +104,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) def get_text_encoding_strategy(self, args): - return strategy_lumina.LuminaTextEncodingStrategy(args.apply_gemma2_attn_mask) + return strategy_lumina.LuminaTextEncodingStrategy() def get_text_encoders_train_flags(self, args, text_encoders): return [self.train_gemma2] @@ -117,7 +117,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): args.text_encoder_batch_size, args.skip_cache_check, is_partial=self.train_gemma2, - apply_gemma2_attn_mask=args.apply_gemma2_attn_mask, ) else: return None @@ -144,11 +143,15 @@ class 
LuminaNetworkTrainer(train_network.NetworkTrainer): # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[0].to( + accelerator.device, dtype=weight_dtype + ) # always not fp8 if text_encoders[0].dtype == torch.float8_e4m3fn: # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + self.prepare_text_encoder_fp8( + 0, text_encoders[0], text_encoders[0].dtype, weight_dtype + ) else: # otherwise, we need to convert it to target dtype text_encoders[0].to(weight_dtype) @@ -158,21 +161,39 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # cache sample prompts if args.sample_prompts is not None: - logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + logger.info( + f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" + ) - tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() - text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = ( + strategy_base.TokenizeStrategy.get_strategy() + ) + text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = ( + strategy_base.TextEncodingStrategy.get_strategy() + ) prompts = train_util.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + sample_prompts_te_outputs = ( + {} + ) # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: - for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + for p in [ + prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", ""), +
]: if p not in sample_prompts_te_outputs: - logger.info(f"cache Text Encoder outputs for prompt: {p}") + logger.info( + f"cache Text Encoder outputs for prompt: {p}" + ) tokens_and_masks = tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + sample_prompts_te_outputs[p] = ( + text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + # attention mask is always applied by encode_tokens now + ) ) self.sample_prompts_te_outputs = sample_prompts_te_outputs @@ -261,10 +282,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # May not need to pack/unpack? # pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入 # packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) - # packed_latent_height, packed_latent_width = ( - # noisy_model_input.shape[2] // 2, - # noisy_model_input.shape[3] // 2, - # ) # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -274,16 +291,18 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): t.requires_grad_(True) # Unpack Gemma2 outputs - gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds + gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask): with torch.set_grad_enabled(is_train), accelerator.autocast(): # NextDiT forward expects (x, t, cap_feats, cap_mask) model_pred = unet( - x=img, # image latents (B, C, H, W) + x=img,  # image latents (B, C, H, W) t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features - cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask + cap_mask=gemma2_attn_mask.to( + dtype=torch.int32 + ), # Gemma2的attention mask ) return model_pred @@ -326,13 +345,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
gemma2_hidden_states=gemma2_hidden_states[ diff_output_pr_indices ], - input_ids=input_ids[diff_output_pr_indices], timesteps=timesteps[diff_output_pr_indices], - gemma2_attn_mask=( - gemma2_attn_mask[diff_output_pr_indices] - if gemma2_attn_mask is not None - else None - ), + gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]), ) network.set_multiplier(1.0) @@ -358,7 +372,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) def update_metadata(self, metadata, args): - metadata["ss_apply_gemma2_attn_mask"] = args.apply_gemma2_attn_mask metadata["ss_weighting_scheme"] = args.weighting_scheme metadata["ss_logit_mean"] = args.logit_mean metadata["ss_logit_std"] = args.logit_std @@ -373,7 +386,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): - text_encoder.model.embed_tokens.requires_grad_(True) + text_encoder.embed_tokens.requires_grad_(True) def prepare_text_encoder_fp8( self, index, text_encoder, te_weight_dtype, weight_dtype @@ -382,7 +395,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}" ) text_encoder.to(te_weight_dtype) # fp8 - text_encoder.model.embed_tokens.to(dtype=weight_dtype) + text_encoder.embed_tokens.to(dtype=weight_dtype) def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module