Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 23:01:22 +00:00)

Commit: update for always use gemma2 mask
@@ -227,7 +227,7 @@ def sample_image_inference(
     )
     timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)
     img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
-    gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) if args.apply_gemma2_attn_mask else None
+    gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device)

     # if controlnet_image is not None:
     #     controlnet_image = Image.open(controlnet_image).convert("RGB")
@@ -511,11 +511,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
         help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev"
         " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
     )
-    parser.add_argument(
-        "--apply_gemma2_attn_mask",
-        action="store_true",
-        help="apply attention mask to Gemma2 encode and NextDIT double blocks / Gemma2エンコードとNextDITダブルブロックにアテンションマスクを適用する",
-    )

     parser.add_argument(
         "--guidance_scale",
@@ -47,7 +47,7 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
             pad_to_multiple_of=8,
             truncation=True,
         )
-        return encodings.input_ids, encodings.attention_mask
+        return [encodings.input_ids, encodings.attention_mask]

     def tokenize_with_weights(
         self, text: str | List[str]
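
A minimal sketch of the tokenize contract after this change, with dummy tensors standing in for real Gemma2 tokenizer output:

import torch

# tokenize() now returns a list [input_ids, attention_mask] rather than a tuple,
# and downstream code unpacks both tensors together.
input_ids = torch.randint(0, 32000, (1, 256))           # stand-in for encodings.input_ids
attention_mask = torch.ones(1, 256, dtype=torch.long)   # stand-in for encodings.attention_mask
tokens = [input_ids, attention_mask]                     # what tokenize() returns
ids, mask = tokens                                       # how encode_tokens() consumes it
assert ids.shape == mask.shape
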
@@ -59,47 +59,36 @@ class LuminaTokenizeStrategy(TokenizeStrategy):


 class LuminaTextEncodingStrategy(TextEncodingStrategy):
-    def __init__(self, apply_gemma2_attn_mask: Optional[bool] = None) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.apply_gemma2_attn_mask = apply_gemma2_attn_mask

     def encode_tokens(
         self,
         tokenize_strategy: TokenizeStrategy,
         models: List[Any],
-        tokens: torch.Tensor,
-        attention_masks: torch.Tensor,
-        apply_gemma2_attn_mask: Optional[bool] = None,
-    ) -> torch.Tensor:
-        if apply_gemma2_attn_mask is None:
-            apply_gemma2_attn_mask = self.apply_gemma2_attn_mask
-
+        tokens: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         text_encoder = models[0]

-        # Create position IDs
-        position_ids = attention_masks.cumsum(-1) - 1
-        position_ids.masked_fill_(attention_masks == 0, 1)
+        input_ids, attention_masks = tokens

         outputs = text_encoder(
-            input_ids=tokens.to(text_encoder.device),
-            attention_mask=attention_masks.to(text_encoder.device) if apply_gemma2_attn_mask else None,
-            position_ids=position_ids.to(text_encoder.device),
+            input_ids=input_ids.to(text_encoder.device),
+            attention_mask=attention_masks.to(text_encoder.device),
             output_hidden_states=True,
             return_dict=True,
         )

-        return outputs.hidden_states[-2]
+        return outputs.hidden_states[-2], input_ids, attention_masks

     def encode_tokens_with_weights(
         self,
         tokenize_strategy: TokenizeStrategy,
         models: List[Any],
-        tokens: torch.Tensor,
+        tokens: List[torch.Tensor],
         weights_list: List[torch.Tensor],
-        attention_masks: torch.Tensor
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # For simplicity, use uniform weighting
-        return self.encode_tokens(tokenize_strategy, models, tokens, attention_masks)
+        return self.encode_tokens(tokenize_strategy, models, tokens)


 class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
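
A runnable sketch of the new encode_tokens contract, using a stub in place of the Gemma2 encoder (the hidden size 2304 and the stub itself are assumptions for illustration):

from types import SimpleNamespace
import torch

def encode_tokens_sketch(text_encoder, tokens):
    # mirrors the new signature: tokens is [input_ids, attention_mask],
    # the mask is always passed, and ids/mask are returned alongside the states
    input_ids, attention_masks = tokens
    outputs = text_encoder(
        input_ids=input_ids,
        attention_mask=attention_masks,
        output_hidden_states=True,
        return_dict=True,
    )
    return outputs.hidden_states[-2], input_ids, attention_masks

def fake_gemma2(input_ids, attention_mask, output_hidden_states, return_dict):
    # stand-in for a Gemma2 text encoder; returns three dummy hidden-state layers
    b, t = input_ids.shape
    return SimpleNamespace(hidden_states=[torch.zeros(b, t, 2304) for _ in range(3)])

tokens = [torch.zeros(1, 256, dtype=torch.long), torch.ones(1, 256, dtype=torch.long)]
hidden, ids, mask = encode_tokens_sketch(fake_gemma2, tokens)
print(hidden.shape, ids.shape, mask.shape)
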
@@ -111,7 +100,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         batch_size: int,
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
-        apply_gemma2_attn_mask: bool = False,
     ) -> None:
         super().__init__(
             cache_to_disk,
@@ -119,7 +107,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
             skip_disk_cache_validity_check,
             is_partial,
         )
-        self.apply_gemma2_attn_mask = apply_gemma2_attn_mask

     def get_outputs_npz_path(self, image_abs_path: str) -> str:
         return (
@@ -146,7 +133,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
             if "apply_gemma2_attn_mask" not in npz:
                 return False
             npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"]
-            if npz_apply_gemma2_attn_mask != self.apply_gemma2_attn_mask:
+            if not npz_apply_gemma2_attn_mask:
                 return False
         except Exception as e:
             logger.error(f"Error loading file: {npz_path}")
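
A small sketch of what the validity check now requires of a cached file (the file name is illustrative):

import numpy as np

# a cache entry is only reused if it was written with apply_gemma2_attn_mask truthy;
# entries saved without the mask (or without the key) are recomputed
np.savez("te_cache_example.npz", apply_gemma2_attn_mask=True)
with np.load("te_cache_example.npz") as npz:
    reusable = "apply_gemma2_attn_mask" in npz and bool(npz["apply_gemma2_attn_mask"])
print(reusable)  # True
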
@@ -174,18 +161,18 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         captions = [info.caption for info in infos]

         if self.is_weighted:
-            tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights(
+            tokens, weights_list = tokenize_strategy.tokenize_with_weights(
                 captions
             )
             with torch.no_grad():
-                hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights(
-                    tokenize_strategy, models, tokens, weights_list, attention_masks
+                hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens_with_weights(
+                    tokenize_strategy, models, tokens, weights_list
                 )
         else:
-            tokens, attention_masks = tokenize_strategy.tokenize(captions)
+            tokens = tokenize_strategy.tokenize(captions)
             with torch.no_grad():
-                hidden_state = lumina_text_encoding_strategy.encode_tokens(
-                    tokenize_strategy, models, tokens, attention_masks
+                hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens(
+                    tokenize_strategy, models, tokens
                 )

         if hidden_state.dtype != torch.float32:
@@ -200,7 +187,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
             hidden_state_i = hidden_state[i]
             attention_mask_i = attention_mask[i]
             input_ids_i = input_ids[i]
-            apply_gemma2_attn_mask_i = self.apply_gemma2_attn_mask

             if self.cache_to_disk:
                 np.savez(
@@ -208,7 +194,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                     hidden_state=hidden_state_i,
                     attention_mask=attention_mask_i,
                     input_ids=input_ids_i,
-                    apply_gemma2_attn_mask=apply_gemma2_attn_mask_i,
+                    apply_gemma2_attn_mask=True
                 )
             else:
                 info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]
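
For reference, a sketch of the fields written to each .npz after this change (dummy arrays; shapes and the path are illustrative only):

import numpy as np

hidden_state_i = np.zeros((256, 2304), dtype=np.float32)  # assumed shape
attention_mask_i = np.ones((256,), dtype=np.int64)
input_ids_i = np.zeros((256,), dtype=np.int64)
np.savez(
    "example_te_outputs.npz",        # real path comes from get_outputs_npz_path
    hidden_state=hidden_state_i,
    attention_mask=attention_mask_i,
    input_ids=input_ids_i,
    apply_gemma2_attn_mask=True,     # constant now that the mask is always applied
)
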
@@ -64,7 +64,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):

         if args.fp8_base:
             # check dtype of model
-            if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
+            if (
+                model.dtype == torch.float8_e4m3fnuz
+                or model.dtype == torch.float8_e5m2
+                or model.dtype == torch.float8_e5m2fnuz
+            ):
                 raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
             elif model.dtype == torch.float8_e4m3fn:
                 logger.info("Loaded fp8 Lumina 2 model")
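
A membership-style variant of this dtype check, shown as a sketch only (not what the diff uses); float8_e4m3fn remains the only accepted fp8 dtype:

import torch

unsupported_fp8 = {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}
dtype = torch.float8_e4m3fn  # stand-in for model.dtype
if dtype in unsupported_fp8:
    raise ValueError(f"Unsupported fp8 model dtype: {dtype}")
elif dtype == torch.float8_e4m3fn:
    print("Loaded fp8 Lumina 2 model")
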
@@ -80,13 +84,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
         # model.enable_block_swap(args.blocks_to_swap, accelerator.device)
         # self.is_swapping_blocks = True

-        gemma2 = lumina_util.load_gemma2(
-            args.gemma2, weight_dtype, "cpu"
-        )
+        gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
         gemma2.eval()
-        ae = lumina_util.load_ae(
-            args.ae, weight_dtype, "cpu"
-        )
+        ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")

         return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
@@ -104,7 +104,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
         )

     def get_text_encoding_strategy(self, args):
-        return strategy_lumina.LuminaTextEncodingStrategy(args.apply_gemma2_attn_mask)
+        return strategy_lumina.LuminaTextEncodingStrategy()

     def get_text_encoders_train_flags(self, args, text_encoders):
         return [self.train_gemma2]
@@ -117,7 +117,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
                 args.text_encoder_batch_size,
                 args.skip_cache_check,
                 is_partial=self.train_gemma2,
-                apply_gemma2_attn_mask=args.apply_gemma2_attn_mask,
             )
         else:
             return None
@@ -144,11 +143,15 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):

         # When TE is not be trained, it will not be prepared so we need to use explicit autocast
         logger.info("move text encoders to gpu")
-        text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
+        text_encoders[0].to(
+            accelerator.device, dtype=weight_dtype
+        )  # always not fp8

         if text_encoders[0].dtype == torch.float8_e4m3fn:
             # if we load fp8 weights, the model is already fp8, so we use it as is
-            self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
+            self.prepare_text_encoder_fp8(
+                1, text_encoders[1], text_encoders[1].dtype, weight_dtype
+            )
         else:
             # otherwise, we need to convert it to target dtype
             text_encoders[0].to(weight_dtype)
@@ -158,21 +161,39 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):

         # cache sample prompts
         if args.sample_prompts is not None:
-            logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
+            logger.info(
+                f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
+            )

-            tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
-            text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
+            tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = (
+                strategy_base.TokenizeStrategy.get_strategy()
+            )
+            text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
+                strategy_base.TextEncodingStrategy.get_strategy()
+            )

             prompts = train_util.load_prompts(args.sample_prompts)
-            sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
+            sample_prompts_te_outputs = (
+                {}
+            )  # key: prompt, value: text encoder outputs
             with accelerator.autocast(), torch.no_grad():
                 for prompt_dict in prompts:
-                    for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
+                    for p in [
+                        prompt_dict.get("prompt", ""),
+                        prompt_dict.get("negative_prompt", ""),
+                    ]:
                         if p not in sample_prompts_te_outputs:
-                            logger.info(f"cache Text Encoder outputs for prompt: {p}")
+                            logger.info(
+                                f"cache Text Encoder outputs for prompt: {p}"
+                            )
                             tokens_and_masks = tokenize_strategy.tokenize(p)
-                            sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
-                                tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
+                            sample_prompts_te_outputs[p] = (
+                                text_encoding_strategy.encode_tokens(
+                                    tokenize_strategy,
+                                    text_encoders,
+                                    tokens_and_masks,
+                                    args.apply_t5_attn_mask,
+                                )
+                            )
             self.sample_prompts_te_outputs = sample_prompts_te_outputs
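
A sketch of the structure this builds: each prompt and negative prompt maps to its cached text-encoder outputs so sampling does not re-run Gemma2 (placeholder tensors and shapes):

import torch

sample_prompts_te_outputs = {
    "a cat in the snow": (
        torch.zeros(1, 256, 2304),              # hidden_state (placeholder shape)
        torch.zeros(1, 256, dtype=torch.long),  # input_ids
        torch.ones(1, 256, dtype=torch.long),   # attention_mask
    ),
}
hidden, ids, mask = sample_prompts_te_outputs["a cat in the snow"]
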
@@ -261,10 +282,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
         # May not need to pack/unpack?
         # pack latents and get img_ids - this part can be kept because NextDiT also needs packed-format input
         # packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
-        # packed_latent_height, packed_latent_width = (
-        #     noisy_model_input.shape[2] // 2,
-        #     noisy_model_input.shape[3] // 2,
-        # )

         # ensure the hidden state will require grad
         if args.gradient_checkpointing:
@@ -274,16 +291,18 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
                     t.requires_grad_(True)

         # Unpack Gemma2 outputs
-        gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds
+        gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds

         def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
             with torch.set_grad_enabled(is_train), accelerator.autocast():
                 # NextDiT forward expects (x, t, cap_feats, cap_mask)
                 model_pred = unet(
-                    x=img, # image latents (B, C, H, W)
+                    x=img,  # image latents (B, C, H, W)
                     t=timesteps / 1000,  # timesteps must be divided by 1000 to match what the model expects
                     cap_feats=gemma2_hidden_states,  # Gemma2 hidden states as caption features
-                    cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2 attention mask
+                    cap_mask=gemma2_attn_mask.to(
+                        dtype=torch.int32
+                    ),  # Gemma2 attention mask
                 )
                 return model_pred
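
A minimal sketch of the cap_mask handling in call_dit: per the comment above, NextDiT takes the caption mask as an integer tensor, so the Gemma2 attention mask is cast to int32 before the forward call (dummy mask below):

import torch

gemma2_attn_mask = torch.ones(1, 256, dtype=torch.bool)  # placeholder mask
cap_mask = gemma2_attn_mask.to(dtype=torch.int32)        # dtype expected by NextDiT per the diff
assert cap_mask.dtype == torch.int32
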
@@ -326,13 +345,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
                         gemma2_hidden_states=gemma2_hidden_states[
                             diff_output_pr_indices
                         ],
                         input_ids=input_ids[diff_output_pr_indices],
                         timesteps=timesteps[diff_output_pr_indices],
-                        gemma2_attn_mask=(
-                            gemma2_attn_mask[diff_output_pr_indices]
-                            if gemma2_attn_mask is not None
-                            else None
-                        ),
+                        gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]),
                     )
                     network.set_multiplier(1.0)
@@ -358,7 +372,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
         )

     def update_metadata(self, metadata, args):
-        metadata["ss_apply_gemma2_attn_mask"] = args.apply_gemma2_attn_mask
         metadata["ss_weighting_scheme"] = args.weighting_scheme
         metadata["ss_logit_mean"] = args.logit_mean
         metadata["ss_logit_std"] = args.logit_std
@@ -373,7 +386,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
         return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

     def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
-        text_encoder.model.embed_tokens.requires_grad_(True)
+        text_encoder.embed_tokens.requires_grad_(True)

     def prepare_text_encoder_fp8(
         self, index, text_encoder, te_weight_dtype, weight_dtype
@@ -382,7 +395,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
             f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}"
         )
         text_encoder.to(te_weight_dtype)  # fp8
-        text_encoder.model.embed_tokens.to(dtype=weight_dtype)
+        text_encoder.embed_tokens.to(dtype=weight_dtype)

     def prepare_unet_with_accelerator(
         self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module