mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Update T5 attention mask handling in FLUX
This commit is contained in:
@@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
|||||||
The command to install PyTorch is as follows:
|
The command to install PyTorch is as follows:
|
||||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||||
|
|
||||||
|
Aug 21, 2024:
|
||||||
|
The behavior of `--apply_t5_attn_mask` has changed. Previously, the T5 output was simply zero-padded; now the mask is applied in two places: (1) when encoding with T5, and (2) in the attention of the Double Blocks. Fine-tuning, LoRA training, and inference in `flux_mini_inference.py` are all affected by this change.
|
||||||
|
|
||||||
Aug 20, 2024 (update 3):
|
Aug 20, 2024 (update 3):
|
||||||
__Experimental__ The multi-resolution training is now supported with caching latents to disk.
|
__Experimental__ The multi-resolution training is now supported with caching latents to disk.
|
||||||
|
|
||||||
|
|||||||
@@ -70,12 +70,22 @@ def denoise(
|
|||||||
vec: torch.Tensor,
|
vec: torch.Tensor,
|
||||||
timesteps: list[float],
|
timesteps: list[float],
|
||||||
guidance: float = 4.0,
|
guidance: float = 4.0,
|
||||||
|
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
# this is ignored for schnell
|
# this is ignored for schnell
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec)
|
pred = model(
|
||||||
|
img=img,
|
||||||
|
img_ids=img_ids,
|
||||||
|
txt=txt,
|
||||||
|
txt_ids=txt_ids,
|
||||||
|
y=vec,
|
||||||
|
timesteps=t_vec,
|
||||||
|
guidance=guidance_vec,
|
||||||
|
txt_attention_mask=t5_attn_mask,
|
||||||
|
)
|
||||||
|
|
||||||
img = img + (t_prev - t_curr) * pred
|
img = img + (t_prev - t_curr) * pred
|
||||||
|
|
||||||
@@ -92,6 +102,7 @@ def do_sample(
|
|||||||
txt_ids: torch.Tensor,
|
txt_ids: torch.Tensor,
|
||||||
num_steps: int,
|
num_steps: int,
|
||||||
guidance: float,
|
guidance: float,
|
||||||
|
t5_attn_mask: Optional[torch.Tensor],
|
||||||
is_schnell: bool,
|
is_schnell: bool,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
flux_dtype: torch.dtype,
|
flux_dtype: torch.dtype,
|
||||||
@@ -101,10 +112,14 @@ def do_sample(
|
|||||||
# denoise initial noise
|
# denoise initial noise
|
||||||
if accelerator:
|
if accelerator:
|
||||||
with accelerator.autocast(), torch.no_grad():
|
with accelerator.autocast(), torch.no_grad():
|
||||||
x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance)
|
x = denoise(
|
||||||
|
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
||||||
x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance)
|
x = denoise(
|
||||||
|
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
|
||||||
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -156,14 +171,14 @@ def generate_image(
|
|||||||
clip_l.to(clip_l_dtype)
|
clip_l.to(clip_l_dtype)
|
||||||
t5xxl.to(t5xxl_dtype)
|
t5xxl.to(t5xxl_dtype)
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
_, t5_out, txt_ids = encoding_strategy.encode_tokens(
|
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
||||||
l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||||
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
||||||
_, t5_out, txt_ids = encoding_strategy.encode_tokens(
|
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -186,7 +201,11 @@ def generate_image(
|
|||||||
steps = 4 if is_schnell else 50
|
steps = 4 if is_schnell else 50
|
||||||
|
|
||||||
img_ids = img_ids.to(device)
|
img_ids = img_ids.to(device)
|
||||||
x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype)
|
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
||||||
|
|
||||||
|
x = do_sample(
|
||||||
|
accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype
|
||||||
|
)
|
||||||
if args.offload:
|
if args.offload:
|
||||||
model = model.cpu()
|
model = model.cpu()
|
||||||
# del model
|
# del model
|
||||||
|
|||||||
@@ -610,7 +610,10 @@ def train(args):
|
|||||||
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
|
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
|
||||||
|
|
||||||
# call model
|
# call model
|
||||||
l_pooled, t5_out, txt_ids = text_encoder_conds
|
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||||
|
if not args.apply_t5_attn_mask:
|
||||||
|
t5_attn_mask = None
|
||||||
|
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
model_pred = flux(
|
model_pred = flux(
|
||||||
@@ -621,6 +624,7 @@ def train(args):
|
|||||||
y=l_pooled,
|
y=l_pooled,
|
||||||
timesteps=timesteps / 1000,
|
timesteps=timesteps / 1000,
|
||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
|
txt_attention_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# unpack latents
|
# unpack latents
|
||||||
|
|||||||
@@ -233,11 +233,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
self.flux_lower = flux_lower
|
self.flux_lower = flux_lower
|
||||||
self.target_device = device
|
self.target_device = device
|
||||||
|
|
||||||
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
|
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
||||||
self.flux_lower.to("cpu")
|
self.flux_lower.to("cpu")
|
||||||
clean_memory_on_device(self.target_device)
|
clean_memory_on_device(self.target_device)
|
||||||
self.flux_upper.to(self.target_device)
|
self.flux_upper.to(self.target_device)
|
||||||
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance)
|
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
||||||
self.flux_upper.to("cpu")
|
self.flux_upper.to("cpu")
|
||||||
clean_memory_on_device(self.target_device)
|
clean_memory_on_device(self.target_device)
|
||||||
self.flux_lower.to(self.target_device)
|
self.flux_lower.to(self.target_device)
|
||||||
@@ -300,10 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
guidance_vec.requires_grad_(True)
|
guidance_vec.requires_grad_(True)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
l_pooled, t5_out, txt_ids = text_encoder_conds
|
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||||
# print(
|
if not args.apply_t5_attn_mask:
|
||||||
# f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}"
|
t5_attn_mask = None
|
||||||
# )
|
|
||||||
|
|
||||||
if not args.split_mode:
|
if not args.split_mode:
|
||||||
# normal forward
|
# normal forward
|
||||||
@@ -317,6 +316,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
y=l_pooled,
|
y=l_pooled,
|
||||||
timesteps=timesteps / 1000,
|
timesteps=timesteps / 1000,
|
||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
|
txt_attention_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# split forward to reduce memory usage
|
# split forward to reduce memory usage
|
||||||
@@ -337,6 +337,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
y=l_pooled,
|
y=l_pooled,
|
||||||
timesteps=timesteps / 1000,
|
timesteps=timesteps / 1000,
|
||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
|
txt_attention_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# move flux upper back to cpu, and then move flux lower to gpu
|
# move flux upper back to cpu, and then move flux lower to gpu
|
||||||
|
|||||||
@@ -440,10 +440,10 @@ configs = {
|
|||||||
# region math
|
# region math
|
||||||
|
|
||||||
|
|
||||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
||||||
q, k = apply_rope(q, k, pe)
|
q, k = apply_rope(q, k, pe)
|
||||||
|
|
||||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
x = rearrange(x, "B H L D -> B L (H D)")
|
x = rearrange(x, "B H L D -> B L (H D)")
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@@ -607,11 +607,7 @@ class SelfAttention(nn.Module):
|
|||||||
self.norm = QKNorm(head_dim)
|
self.norm = QKNorm(head_dim)
|
||||||
self.proj = nn.Linear(dim, dim)
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
# self.gradient_checkpointing = False
|
# this is not called from DoubleStreamBlock/SingleStreamBlock because they use the attention function directly
|
||||||
|
|
||||||
# def enable_gradient_checkpointing(self):
|
|
||||||
# self.gradient_checkpointing = True
|
|
||||||
|
|
||||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||||
qkv = self.qkv(x)
|
qkv = self.qkv(x)
|
||||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
@@ -620,12 +616,6 @@ class SelfAttention(nn.Module):
|
|||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
# def forward(self, *args, **kwargs):
|
|
||||||
# if self.training and self.gradient_checkpointing:
|
|
||||||
# return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
|
||||||
# else:
|
|
||||||
# return self._forward(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModulationOut:
|
class ModulationOut:
|
||||||
@@ -690,7 +680,9 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
self.cpu_offload_checkpointing = False
|
self.cpu_offload_checkpointing = False
|
||||||
|
|
||||||
def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
def _forward(
|
||||||
|
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
img_mod1, img_mod2 = self.img_mod(vec)
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||||
|
|
||||||
@@ -713,7 +705,18 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
k = torch.cat((txt_k, img_k), dim=2)
|
k = torch.cat((txt_k, img_k), dim=2)
|
||||||
v = torch.cat((txt_v, img_v), dim=2)
|
v = torch.cat((txt_v, img_v), dim=2)
|
||||||
|
|
||||||
attn = attention(q, k, v, pe=pe)
|
# make attention mask if not None
|
||||||
|
attn_mask = None
|
||||||
|
if txt_attention_mask is not None:
|
||||||
|
attn_mask = txt_attention_mask # b, seq_len
|
||||||
|
attn_mask = torch.cat(
|
||||||
|
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1
|
||||||
|
) # b, seq_len + img_len
|
||||||
|
|
||||||
|
# broadcast attn_mask to all heads
|
||||||
|
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
||||||
|
|
||||||
|
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||||
|
|
||||||
# calculate the img blocks
|
# calculate the img blocks
|
||||||
@@ -725,10 +728,12 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
return img, txt
|
return img, txt
|
||||||
|
|
||||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
def forward(
|
||||||
|
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
if not self.cpu_offload_checkpointing:
|
if not self.cpu_offload_checkpointing:
|
||||||
return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False)
|
return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
|
||||||
# cpu offload checkpointing
|
# cpu offload checkpointing
|
||||||
|
|
||||||
def create_custom_forward(func):
|
def create_custom_forward(func):
|
||||||
@@ -739,10 +744,10 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe)
|
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return self._forward(img, txt, vec, pe)
|
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
||||||
|
|
||||||
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||||
# if self.training and self.gradient_checkpointing:
|
# if self.training and self.gradient_checkpointing:
|
||||||
@@ -992,6 +997,7 @@ class Flux(nn.Module):
|
|||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor | None = None,
|
guidance: Tensor | None = None,
|
||||||
|
txt_attention_mask: Tensor | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
@@ -1011,7 +1017,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
if not self.double_blocks_to_swap:
|
if not self.double_blocks_to_swap:
|
||||||
for block in self.double_blocks:
|
for block in self.double_blocks:
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
else:
|
else:
|
||||||
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
||||||
for block_idx in range(self.double_blocks_to_swap):
|
for block_idx in range(self.double_blocks_to_swap):
|
||||||
@@ -1033,7 +1039,7 @@ class Flux(nn.Module):
|
|||||||
block.to(self.device) # move to cuda
|
block.to(self.device) # move to cuda
|
||||||
# print(f"Moved double block {block_idx} to cuda.")
|
# print(f"Moved double block {block_idx} to cuda.")
|
||||||
|
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
|
|
||||||
if moving:
|
if moving:
|
||||||
self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||||
@@ -1164,6 +1170,7 @@ class FluxUpper(nn.Module):
|
|||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor | None = None,
|
guidance: Tensor | None = None,
|
||||||
|
txt_attention_mask: Tensor | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
@@ -1182,7 +1189,7 @@ class FluxUpper(nn.Module):
|
|||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
for block in self.double_blocks:
|
for block in self.double_blocks:
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
|
|
||||||
return img, txt, vec, pe
|
return img, txt, vec, pe
|
||||||
|
|
||||||
|
|||||||
@@ -190,9 +190,10 @@ def sample_image_inference(
|
|||||||
te_outputs = sample_prompts_te_outputs[prompt]
|
te_outputs = sample_prompts_te_outputs[prompt]
|
||||||
else:
|
else:
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||||
|
# strategy has apply_t5_attn_mask option
|
||||||
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||||
|
|
||||||
l_pooled, t5_out, txt_ids = te_outputs
|
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
|
||||||
|
|
||||||
# sample image
|
# sample image
|
||||||
weight_dtype = ae.dtype # TODO give dtype as argument
|
weight_dtype = ae.dtype # TODO give dtype as argument
|
||||||
@@ -208,9 +209,10 @@ def sample_image_inference(
|
|||||||
)
|
)
|
||||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
||||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||||
|
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
||||||
|
|
||||||
with accelerator.autocast(), torch.no_grad():
|
with accelerator.autocast(), torch.no_grad():
|
||||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale)
|
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
|
||||||
|
|
||||||
x = x.float()
|
x = x.float()
|
||||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||||
@@ -289,12 +291,22 @@ def denoise(
|
|||||||
vec: torch.Tensor,
|
vec: torch.Tensor,
|
||||||
timesteps: list[float],
|
timesteps: list[float],
|
||||||
guidance: float = 4.0,
|
guidance: float = 4.0,
|
||||||
|
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
# this is ignored for schnell
|
# this is ignored for schnell
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec)
|
pred = model(
|
||||||
|
img=img,
|
||||||
|
img_ids=img_ids,
|
||||||
|
txt=txt,
|
||||||
|
txt_ids=txt_ids,
|
||||||
|
y=vec,
|
||||||
|
timesteps=t_vec,
|
||||||
|
guidance=guidance_vec,
|
||||||
|
txt_attention_mask=t5_attn_mask,
|
||||||
|
)
|
||||||
|
|
||||||
img = img + (t_prev - t_curr) * pred
|
img = img + (t_prev - t_curr) * pred
|
||||||
|
|
||||||
@@ -498,7 +510,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--apply_t5_attn_mask",
|
"--apply_t5_attn_mask",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
|
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||||
|
|||||||
@@ -64,22 +64,25 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
|||||||
l_tokens, t5_tokens = tokens[:2]
|
l_tokens, t5_tokens = tokens[:2]
|
||||||
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
|
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
|
||||||
|
|
||||||
|
# clip_l is None when using T5 only
|
||||||
if clip_l is not None and l_tokens is not None:
|
if clip_l is not None and l_tokens is not None:
|
||||||
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
|
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
|
||||||
else:
|
else:
|
||||||
l_pooled = None
|
l_pooled = None
|
||||||
|
|
||||||
|
# t5xxl is None when using CLIP only
|
||||||
if t5xxl is not None and t5_tokens is not None:
|
if t5xxl is not None and t5_tokens is not None:
|
||||||
# t5_out is [b, max length, 4096]
|
# t5_out is [b, max length, 4096]
|
||||||
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
|
attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
|
||||||
if apply_t5_attn_mask:
|
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
|
||||||
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
# if zero_pad_t5_output:
|
||||||
|
# t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
||||||
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
|
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
|
||||||
else:
|
else:
|
||||||
t5_out = None
|
t5_out = None
|
||||||
txt_ids = None
|
txt_ids = None
|
||||||
|
|
||||||
return [l_pooled, t5_out, txt_ids]
|
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
|
||||||
|
|
||||||
|
|
||||||
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||||
@@ -115,6 +118,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
return False
|
return False
|
||||||
if "txt_ids" not in npz:
|
if "txt_ids" not in npz:
|
||||||
return False
|
return False
|
||||||
|
if "t5_attn_mask" not in npz:
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading file: {npz_path}")
|
logger.error(f"Error loading file: {npz_path}")
|
||||||
raise e
|
raise e
|
||||||
@@ -129,12 +134,12 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
l_pooled = data["l_pooled"]
|
l_pooled = data["l_pooled"]
|
||||||
t5_out = data["t5_out"]
|
t5_out = data["t5_out"]
|
||||||
txt_ids = data["txt_ids"]
|
txt_ids = data["txt_ids"]
|
||||||
|
t5_attn_mask = data["t5_attn_mask"]
|
||||||
|
|
||||||
if self.apply_t5_attn_mask:
|
if self.apply_t5_attn_mask:
|
||||||
t5_attn_mask = data["t5_attn_mask"]
|
|
||||||
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
|
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
|
||||||
|
|
||||||
return [l_pooled, t5_out, txt_ids]
|
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||||
|
|
||||||
def cache_batch_outputs(
|
def cache_batch_outputs(
|
||||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||||
@@ -145,7 +150,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# attn_mask is not applied when caching to disk: it is applied when loading from disk
|
# attn_mask is not applied when caching to disk: it is applied when loading from disk
|
||||||
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
|
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
|
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -159,15 +164,15 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
l_pooled = l_pooled.cpu().numpy()
|
l_pooled = l_pooled.cpu().numpy()
|
||||||
t5_out = t5_out.cpu().numpy()
|
t5_out = t5_out.cpu().numpy()
|
||||||
txt_ids = txt_ids.cpu().numpy()
|
txt_ids = txt_ids.cpu().numpy()
|
||||||
|
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
|
||||||
|
|
||||||
for i, info in enumerate(infos):
|
for i, info in enumerate(infos):
|
||||||
l_pooled_i = l_pooled[i]
|
l_pooled_i = l_pooled[i]
|
||||||
t5_out_i = t5_out[i]
|
t5_out_i = t5_out[i]
|
||||||
txt_ids_i = txt_ids[i]
|
txt_ids_i = txt_ids[i]
|
||||||
|
t5_attn_mask_i = t5_attn_mask[i]
|
||||||
|
|
||||||
if self.cache_to_disk:
|
if self.cache_to_disk:
|
||||||
t5_attn_mask = tokens_and_masks[2]
|
|
||||||
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy()
|
|
||||||
np.savez(
|
np.savez(
|
||||||
info.text_encoder_outputs_npz,
|
info.text_encoder_outputs_npz,
|
||||||
l_pooled=l_pooled_i,
|
l_pooled=l_pooled_i,
|
||||||
@@ -176,7 +181,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
t5_attn_mask=t5_attn_mask_i,
|
t5_attn_mask=t5_attn_mask_i,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i)
|
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
||||||
|
|
||||||
|
|
||||||
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||||
|
|||||||
Reference in New Issue
Block a user