Update T5 attention mask handling in FLUX

This commit is contained in:
Kohya S
2024-08-21 08:02:33 +09:00
parent 6ab48b09d8
commit 7e459c00b2
7 changed files with 101 additions and 50 deletions

View File

@@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 21, 2024:
The behavior of `--apply_t5_attn_mask` has changed. Previously, the T5 output was simply zeroed out at the padded positions; now the mask is applied in two places: (1) when encoding the prompt with T5, and (2) in the attention of the Double Blocks. Fine-tuning, LoRA training, and inference in `flux_mini_inference.py` are all affected by this change.
Aug 20, 2024 (update 3):
__Experimental__ Multi-resolution training is now supported when caching latents to disk.

View File

@@ -70,12 +70,22 @@ def denoise(
vec: torch.Tensor,
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
@@ -92,6 +102,7 @@ def do_sample(
txt_ids: torch.Tensor,
num_steps: int,
guidance: float,
t5_attn_mask: Optional[torch.Tensor],
is_schnell: bool,
device: torch.device,
flux_dtype: torch.dtype,
@@ -101,10 +112,14 @@ def do_sample(
# denoise initial noise
if accelerator:
with accelerator.autocast(), torch.no_grad():
x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance)
x = denoise(
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
)
else:
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance)
x = denoise(
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
)
return x
@@ -156,14 +171,14 @@ def generate_image(
clip_l.to(clip_l_dtype)
t5xxl.to(t5xxl_dtype)
with accelerator.autocast():
_, t5_out, txt_ids = encoding_strategy.encode_tokens(
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
_, t5_out, txt_ids = encoding_strategy.encode_tokens(
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
@@ -186,7 +201,11 @@ def generate_image(
steps = 4 if is_schnell else 50
img_ids = img_ids.to(device)
x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype)
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
x = do_sample(
accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype
)
if args.offload:
model = model.cpu()
# del model

View File

@@ -610,7 +610,10 @@ def train(args):
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
# call model
l_pooled, t5_out, txt_ids = text_encoder_conds
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
@@ -621,6 +624,7 @@ def train(args):
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# unpack latents

View File

@@ -233,11 +233,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.flux_lower = flux_lower
self.target_device = device
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
self.flux_lower.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_upper.to(self.target_device)
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance)
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
self.flux_upper.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_lower.to(self.target_device)
@@ -300,10 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance_vec.requires_grad_(True)
# Predict the noise residual
l_pooled, t5_out, txt_ids = text_encoder_conds
# print(
# f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}"
# )
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
if not args.split_mode:
# normal forward
@@ -317,6 +316,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
else:
# split forward to reduce memory usage
@@ -337,6 +337,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu

View File

@@ -440,10 +440,10 @@ configs = {
# region math
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "B H L D -> B L (H D)")
return x
@@ -607,11 +607,7 @@ class SelfAttention(nn.Module):
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
# self.gradient_checkpointing = False
# def enable_gradient_checkpointing(self):
# self.gradient_checkpointing = True
# this is not called from DoubleStreamBlock/SingleStreamBlock because they use the attention function directly
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
@@ -620,12 +616,6 @@ class SelfAttention(nn.Module):
x = self.proj(x)
return x
# def forward(self, *args, **kwargs):
# if self.training and self.gradient_checkpointing:
# return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
# else:
# return self._forward(*args, **kwargs)
@dataclass
class ModulationOut:
@@ -690,7 +680,9 @@ class DoubleStreamBlock(nn.Module):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
def _forward(
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -713,7 +705,18 @@ class DoubleStreamBlock(nn.Module):
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
# make attention mask if not None
attn_mask = None
if txt_attention_mask is not None:
attn_mask = txt_attention_mask # b, seq_len
attn_mask = torch.cat(
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1
) # b, seq_len + img_len
# broadcast attn_mask to all heads
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
@@ -725,10 +728,12 @@ class DoubleStreamBlock(nn.Module):
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
def forward(
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
) -> tuple[Tensor, Tensor]:
if self.training and self.gradient_checkpointing:
if not self.cpu_offload_checkpointing:
return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False)
return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
# cpu offload checkpointing
def create_custom_forward(func):
@@ -739,10 +744,10 @@ class DoubleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe)
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
else:
return self._forward(img, txt, vec, pe)
return self._forward(img, txt, vec, pe, txt_attention_mask)
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
# if self.training and self.gradient_checkpointing:
@@ -992,6 +997,7 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -1011,7 +1017,7 @@ class Flux(nn.Module):
if not self.double_blocks_to_swap:
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
for block_idx in range(self.double_blocks_to_swap):
@@ -1033,7 +1039,7 @@ class Flux(nn.Module):
block.to(self.device) # move to cuda
# print(f"Moved double block {block_idx} to cuda.")
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if moving:
self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
@@ -1164,6 +1170,7 @@ class FluxUpper(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -1182,7 +1189,7 @@ class FluxUpper(nn.Module):
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
return img, txt, vec, pe

View File

@@ -190,9 +190,10 @@ def sample_image_inference(
te_outputs = sample_prompts_te_outputs[prompt]
else:
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
l_pooled, t5_out, txt_ids = te_outputs
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
# sample image
weight_dtype = ae.dtype # TOFO give dtype as argument
@@ -208,9 +209,10 @@ def sample_image_inference(
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale)
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
x = x.float()
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
@@ -289,12 +291,22 @@ def denoise(
vec: torch.Tensor,
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
@@ -498,7 +510,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
)
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"

View File

@@ -64,22 +64,25 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
l_tokens, t5_tokens = tokens[:2]
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
# clip_l is None when using T5 only
if clip_l is not None and l_tokens is not None:
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
else:
l_pooled = None
# t5xxl is None when using CLIP only
if t5xxl is not None and t5_tokens is not None:
# t5_out is [b, max length, 4096]
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
if apply_t5_attn_mask:
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
# if zero_pad_t5_output:
# t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
else:
t5_out = None
txt_ids = None
return [l_pooled, t5_out, txt_ids]
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
@@ -115,6 +118,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -129,12 +134,12 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
txt_ids = data["txt_ids"]
t5_attn_mask = data["t5_attn_mask"]
if self.apply_t5_attn_mask:
t5_attn_mask = data["t5_attn_mask"]
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
return [l_pooled, t5_out, txt_ids]
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
@@ -145,7 +150,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is not applied when caching to disk: it is applied when loading from disk
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
)
@@ -159,15 +164,15 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
if self.cache_to_disk:
t5_attn_mask = tokens_and_masks[2]
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy()
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
@@ -176,7 +181,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
t5_attn_mask=t5_attn_mask_i,
)
else:
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i)
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
class FluxLatentsCachingStrategy(LatentsCachingStrategy):