Update: always apply the Gemma2 attention mask (remove the --apply_gemma2_attn_mask option)

This commit is contained in:
sdbds
2025-02-17 19:00:18 +08:00
parent bb7bae5dff
commit aa36c48685
3 changed files with 68 additions and 74 deletions

View File

@@ -227,7 +227,7 @@ def sample_image_inference(
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)
img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) if args.apply_gemma2_attn_mask else None
gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device)
# if controlnet_image is not None:
# controlnet_image = Image.open(controlnet_image).convert("RGB")
@@ -511,11 +511,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev"
" / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
)
parser.add_argument(
"--apply_gemma2_attn_mask",
action="store_true",
help="apply attention mask to Gemma2 encode and NextDIT double blocks / Gemma2エンコードとNextDITダブルブロックにアテンションマスクを適用する",
)
parser.add_argument(
"--guidance_scale",

View File

@@ -47,7 +47,7 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
pad_to_multiple_of=8,
truncation=True,
)
return encodings.input_ids, encodings.attention_mask
return [encodings.input_ids, encodings.attention_mask]
def tokenize_with_weights(
self, text: str | List[str]
@@ -59,47 +59,36 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
class LuminaTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_gemma2_attn_mask: Optional[bool] = None) -> None:
def __init__(self) -> None:
super().__init__()
self.apply_gemma2_attn_mask = apply_gemma2_attn_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: torch.Tensor,
attention_masks: torch.Tensor,
apply_gemma2_attn_mask: Optional[bool] = None,
) -> torch.Tensor:
if apply_gemma2_attn_mask is None:
apply_gemma2_attn_mask = self.apply_gemma2_attn_mask
tokens: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
text_encoder = models[0]
# Create position IDs
position_ids = attention_masks.cumsum(-1) - 1
position_ids.masked_fill_(attention_masks == 0, 1)
input_ids, attention_masks = tokens
outputs = text_encoder(
input_ids=tokens.to(text_encoder.device),
attention_mask=attention_masks.to(text_encoder.device) if apply_gemma2_attn_mask else None,
position_ids=position_ids.to(text_encoder.device),
input_ids=input_ids.to(text_encoder.device),
attention_mask=attention_masks.to(text_encoder.device),
output_hidden_states=True,
return_dict=True,
)
return outputs.hidden_states[-2]
return outputs.hidden_states[-2], input_ids, attention_masks
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: torch.Tensor,
tokens: List[torch.Tensor],
weights_list: List[torch.Tensor],
attention_masks: torch.Tensor
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# For simplicity, use uniform weighting
return self.encode_tokens(tokenize_strategy, models, tokens, attention_masks)
return self.encode_tokens(tokenize_strategy, models, tokens)
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
@@ -111,7 +100,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_gemma2_attn_mask: bool = False,
) -> None:
super().__init__(
cache_to_disk,
@@ -119,7 +107,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
skip_disk_cache_validity_check,
is_partial,
)
self.apply_gemma2_attn_mask = apply_gemma2_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return (
@@ -146,7 +133,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
if "apply_gemma2_attn_mask" not in npz:
return False
npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"]
if npz_apply_gemma2_attn_mask != self.apply_gemma2_attn_mask:
if not npz_apply_gemma2_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
@@ -174,18 +161,18 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
captions = [info.caption for info in infos]
if self.is_weighted:
tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights(
tokens, weights_list = tokenize_strategy.tokenize_with_weights(
captions
)
with torch.no_grad():
hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, models, tokens, weights_list, attention_masks
hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, models, tokens, weights_list
)
else:
tokens, attention_masks = tokenize_strategy.tokenize(captions)
tokens = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state = lumina_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens, attention_masks
hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens
)
if hidden_state.dtype != torch.float32:
@@ -200,7 +187,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
apply_gemma2_attn_mask_i = self.apply_gemma2_attn_mask
if self.cache_to_disk:
np.savez(
@@ -208,7 +194,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
apply_gemma2_attn_mask=apply_gemma2_attn_mask_i,
apply_gemma2_attn_mask=True
)
else:
info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]

View File

@@ -64,7 +64,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
if args.fp8_base:
# check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
if (
model.dtype == torch.float8_e4m3fnuz
or model.dtype == torch.float8_e5m2
or model.dtype == torch.float8_e5m2fnuz
):
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 Lumina 2 model")
@@ -80,13 +84,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# model.enable_block_swap(args.blocks_to_swap, accelerator.device)
# self.is_swapping_blocks = True
gemma2 = lumina_util.load_gemma2(
args.gemma2, weight_dtype, "cpu"
)
gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
gemma2.eval()
ae = lumina_util.load_ae(
args.ae, weight_dtype, "cpu"
)
ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
@@ -104,7 +104,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
)
def get_text_encoding_strategy(self, args):
return strategy_lumina.LuminaTextEncodingStrategy(args.apply_gemma2_attn_mask)
return strategy_lumina.LuminaTextEncodingStrategy()
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_gemma2]
@@ -117,7 +117,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_gemma2,
apply_gemma2_attn_mask=args.apply_gemma2_attn_mask,
)
else:
return None
@@ -144,11 +143,15 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[0].to(
accelerator.device, dtype=weight_dtype
) # always not fp8
if text_encoders[0].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
self.prepare_text_encoder_fp8(
1, text_encoders[1], text_encoders[1].dtype, weight_dtype
)
else:
# otherwise, we need to convert it to target dtype
text_encoders[0].to(weight_dtype)
@@ -158,21 +161,39 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
logger.info(
f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
)
tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = (
strategy_base.TokenizeStrategy.get_strategy()
)
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
strategy_base.TextEncodingStrategy.get_strategy()
)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
sample_prompts_te_outputs = (
{}
) # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
for p in [
prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
logger.info(
f"cache Text Encoder outputs for prompt: {p}"
)
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
sample_prompts_te_outputs[p] = (
text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
args.apply_t5_attn_mask,
)
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
@@ -261,10 +282,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# May not need to pack/unpack?
# pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入
# packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
# packed_latent_height, packed_latent_width = (
# noisy_model_input.shape[2] // 2,
# noisy_model_input.shape[3] // 2,
# )
# ensure the hidden state will require grad
if args.gradient_checkpointing:
@@ -274,16 +291,18 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
t.requires_grad_(True)
# Unpack Gemma2 outputs
gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds
gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
with torch.set_grad_enabled(is_train), accelerator.autocast():
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = unet(
x=img, # image latents (B, C, H, W)
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
cap_mask=gemma2_attn_mask.to(
dtype=torch.int32
), # Gemma2的attention mask
)
return model_pred
@@ -326,13 +345,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
gemma2_hidden_states=gemma2_hidden_states[
diff_output_pr_indices
],
input_ids=input_ids[diff_output_pr_indices],
timesteps=timesteps[diff_output_pr_indices],
gemma2_attn_mask=(
gemma2_attn_mask[diff_output_pr_indices]
if gemma2_attn_mask is not None
else None
),
gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]),
)
network.set_multiplier(1.0)
@@ -358,7 +372,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
)
def update_metadata(self, metadata, args):
metadata["ss_apply_gemma2_attn_mask"] = args.apply_gemma2_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
@@ -373,7 +386,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
text_encoder.model.embed_tokens.requires_grad_(True)
text_encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(
self, index, text_encoder, te_weight_dtype, weight_dtype
@@ -382,7 +395,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}"
)
text_encoder.to(te_weight_dtype) # fp8
text_encoder.model.embed_tokens.to(dtype=weight_dtype)
text_encoder.embed_tokens.to(dtype=weight_dtype)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module