From 7e459c00b2e142e40a9452341934c2eb9f70a172 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 08:02:33 +0900 Subject: [PATCH] Update T5 attention mask handling in FLUX --- README.md | 3 +++ flux_minimal_inference.py | 33 +++++++++++++++++++----- flux_train.py | 6 ++++- flux_train_network.py | 13 +++++----- library/flux_models.py | 51 +++++++++++++++++++++---------------- library/flux_train_utils.py | 20 ++++++++++++--- library/strategy_flux.py | 25 ++++++++++-------- 7 files changed, 101 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 1d44c9e5..43edbbed 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024: +The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_minimal_inference.py` have been changed. +Aug 20, 2024 (update 3): __Experimental__ The multi-resolution training is now supported with caching latents to disk. 
diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index b09f6380..5b8aa250 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -70,12 +70,22 @@ def denoise( vec: torch.Tensor, timesteps: list[float], guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) img = img + (t_prev - t_curr) * pred @@ -92,6 +102,7 @@ def do_sample( txt_ids: torch.Tensor, num_steps: int, guidance: float, + t5_attn_mask: Optional[torch.Tensor], is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, @@ -101,10 +112,14 @@ def do_sample( # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): - x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) else: with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): - x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) return x @@ -156,14 +171,14 @@ def generate_image( clip_l.to(clip_l_dtype) t5xxl.to(t5xxl_dtype) with accelerator.autocast(): - _, t5_out, txt_ids = encoding_strategy.encode_tokens( + _, t5_out, txt_ids, t5_attn_mask = 
encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) else: with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): - _, t5_out, txt_ids = encoding_strategy.encode_tokens( + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) @@ -186,7 +201,11 @@ def generate_image( steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) - x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype) + t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None + + x = do_sample( + accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype + ) if args.offload: model = model.cpu() # del model diff --git a/flux_train.py b/flux_train.py index 66996385..ecb8a108 100644 --- a/flux_train.py +++ b/flux_train.py @@ -610,7 +610,10 @@ def train(args): guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) # call model - l_pooled, t5_out, txt_ids = text_encoder_conds + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + with accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( @@ -621,6 +624,7 @@ def train(args): y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) # unpack latents diff --git 
a/flux_train_network.py b/flux_train_network.py index 002252c8..49bd270c 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -233,11 +233,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.flux_lower = flux_lower self.target_device = device - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) self.flux_upper.to("cpu") clean_memory_on_device(self.target_device) self.flux_lower.to(self.target_device) @@ -300,10 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): guidance_vec.requires_grad_(True) # Predict the noise residual - l_pooled, t5_out, txt_ids = text_encoder_conds - # print( - # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" - # ) + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None if not args.split_mode: # normal forward @@ -317,6 +316,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) else: # split forward to reduce memory usage @@ -337,6 +337,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) # move flux upper back to cpu, and then move flux lower to gpu diff --git a/library/flux_models.py b/library/flux_models.py index 11ef647a..6f28da60 
100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -440,10 +440,10 @@ configs = { # region math -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)") return x @@ -607,11 +607,7 @@ class SelfAttention(nn.Module): self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) - # self.gradient_checkpointing = False - - # def enable_gradient_checkpointing(self): - # self.gradient_checkpointing = True - + # this is not called from DoubleStreamBlock/SingleStreamBlock because they use the attention function directly def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) @@ -620,12 +616,6 @@ class SelfAttention(nn.Module): x = self.proj(x) return x - # def forward(self, *args, **kwargs): - # if self.training and self.gradient_checkpointing: - # return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) - # else: - # return self._forward(*args, **kwargs) - @dataclass class ModulationOut: @@ -690,7 +680,9 @@ class DoubleStreamBlock(nn.Module): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def _forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -713,7 +705,18 @@ class DoubleStreamBlock(nn.Module): k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) - attn = attention(q, k, v, pe=pe) + # make attention mask if 
not None + attn_mask = None + if txt_attention_mask is not None: + attn_mask = txt_attention_mask # b, seq_len + attn_mask = torch.cat( + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1 + ) # b, seq_len + img_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img blocks @@ -725,10 +728,12 @@ class DoubleStreamBlock(nn.Module): txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False) + return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) # cpu offload checkpointing def create_custom_forward(func): @@ -739,10 +744,10 @@ class DoubleStreamBlock(nn.Module): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe) + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask) else: - return self._forward(img, txt, vec, pe) + return self._forward(img, txt, vec, pe, txt_attention_mask) # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -992,6 +997,7 @@ class Flux(nn.Module): timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: 
raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1011,7 +1017,7 @@ class Flux(nn.Module): if not self.double_blocks_to_swap: for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning for block_idx in range(self.double_blocks_to_swap): @@ -1033,7 +1039,7 @@ class Flux(nn.Module): block.to(self.device) # move to cuda # print(f"Moved double block {block_idx} to cuda.") - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) if moving: self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) @@ -1164,6 +1170,7 @@ class FluxUpper(nn.Module): timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1182,7 +1189,7 @@ class FluxUpper(nn.Module): pe = self.pe_embedder(ids) for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) return img, txt, vec, pe diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 3f9e8660..1d3f80d7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -190,9 +190,10 @@ def sample_image_inference( te_outputs = sample_prompts_te_outputs[prompt] else: tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_t5_attn_mask option te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - l_pooled, t5_out, txt_ids = te_outputs + l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs # sample image weight_dtype = ae.dtype # TOFO give 
dtype as argument @@ -208,9 +209,10 @@ def sample_image_inference( ) timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -289,12 +291,22 @@ def denoise( vec: torch.Tensor, timesteps: list[float], guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) img = img + (t_prev - t_curr) * pred @@ -498,7 +510,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--apply_t5_attn_mask", action="store_true", - help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", ) parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" diff --git a/library/strategy_flux.py 
b/library/strategy_flux.py index 5c620f3d..737af390 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -64,22 +64,25 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None + # clip_l is None when using T5 only if clip_l is not None and l_tokens is not None: l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] else: l_pooled = None + # t5xxl is None when using CLIP only if t5xxl is not None and t5_tokens is not None: # t5_out is [b, max length, 4096] - t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) - if apply_t5_attn_mask: - t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device) + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True) + # if zero_pad_t5_output: + # t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) else: t5_out = None txt_ids = None - return [l_pooled, t5_out, txt_ids] + return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -115,6 +118,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return False if "txt_ids" not in npz: return False + if "t5_attn_mask" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -129,12 +134,12 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): l_pooled = data["l_pooled"] t5_out = data["t5_out"] txt_ids = data["txt_ids"] + t5_attn_mask = data["t5_attn_mask"] if self.apply_t5_attn_mask: - t5_attn_mask = data["t5_attn_mask"] t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) - 
return [l_pooled, t5_out, txt_ids] + return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List @@ -145,7 +150,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): # attn_mask is not applied when caching to disk: it is applied when loading from disk - l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) @@ -159,15 +164,15 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): l_pooled = l_pooled.cpu().numpy() t5_out = t5_out.cpu().numpy() txt_ids = txt_ids.cpu().numpy() + t5_attn_mask = tokens_and_masks[2].cpu().numpy() for i, info in enumerate(infos): l_pooled_i = l_pooled[i] t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] + t5_attn_mask_i = t5_attn_mask[i] if self.cache_to_disk: - t5_attn_mask = tokens_and_masks[2] - t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() np.savez( info.text_encoder_outputs_npz, l_pooled=l_pooled_i, @@ -176,7 +181,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): t5_attn_mask=t5_attn_mask_i, ) else: - info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i) + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) class FluxLatentsCachingStrategy(LatentsCachingStrategy):