mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Update T5 attention mask handling in FLUX
This commit is contained in:
@@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
Aug 21, 2024:
|
||||
The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed.
|
||||
|
||||
Aug 20, 2024 (update 3):
|
||||
__Experimental__ The multi-resolution training is now supported with caching latents to disk.
|
||||
|
||||
|
||||
@@ -70,12 +70,22 @@ def denoise(
|
||||
vec: torch.Tensor,
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
@@ -92,6 +102,7 @@ def do_sample(
|
||||
txt_ids: torch.Tensor,
|
||||
num_steps: int,
|
||||
guidance: float,
|
||||
t5_attn_mask: Optional[torch.Tensor],
|
||||
is_schnell: bool,
|
||||
device: torch.device,
|
||||
flux_dtype: torch.dtype,
|
||||
@@ -101,10 +112,14 @@ def do_sample(
|
||||
# denoise initial noise
|
||||
if accelerator:
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance)
|
||||
x = denoise(
|
||||
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
|
||||
)
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
||||
x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance)
|
||||
x = denoise(
|
||||
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
@@ -156,14 +171,14 @@ def generate_image(
|
||||
clip_l.to(clip_l_dtype)
|
||||
t5xxl.to(t5xxl_dtype)
|
||||
with accelerator.autocast():
|
||||
_, t5_out, txt_ids = encoding_strategy.encode_tokens(
|
||||
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||
)
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
||||
l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
||||
_, t5_out, txt_ids = encoding_strategy.encode_tokens(
|
||||
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||
)
|
||||
|
||||
@@ -186,7 +201,11 @@ def generate_image(
|
||||
steps = 4 if is_schnell else 50
|
||||
|
||||
img_ids = img_ids.to(device)
|
||||
x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype)
|
||||
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
||||
|
||||
x = do_sample(
|
||||
accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype
|
||||
)
|
||||
if args.offload:
|
||||
model = model.cpu()
|
||||
# del model
|
||||
|
||||
@@ -610,7 +610,10 @@ def train(args):
|
||||
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
|
||||
|
||||
# call model
|
||||
l_pooled, t5_out, txt_ids = text_encoder_conds
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = flux(
|
||||
@@ -621,6 +624,7 @@ def train(args):
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
|
||||
@@ -233,11 +233,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.flux_lower = flux_lower
|
||||
self.target_device = device
|
||||
|
||||
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
|
||||
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
||||
self.flux_lower.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_upper.to(self.target_device)
|
||||
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance)
|
||||
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_lower.to(self.target_device)
|
||||
@@ -300,10 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
guidance_vec.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
l_pooled, t5_out, txt_ids = text_encoder_conds
|
||||
# print(
|
||||
# f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}"
|
||||
# )
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
if not args.split_mode:
|
||||
# normal forward
|
||||
@@ -317,6 +316,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
# split forward to reduce memory usage
|
||||
@@ -337,6 +337,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
# move flux upper back to cpu, and then move flux lower to gpu
|
||||
|
||||
@@ -440,10 +440,10 @@ configs = {
|
||||
# region math
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
@@ -607,11 +607,7 @@ class SelfAttention(nn.Module):
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
# self.gradient_checkpointing = False
|
||||
|
||||
# def enable_gradient_checkpointing(self):
|
||||
# self.gradient_checkpointing = True
|
||||
|
||||
# this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
@@ -620,12 +616,6 @@ class SelfAttention(nn.Module):
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
# def forward(self, *args, **kwargs):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
# return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
||||
# else:
|
||||
# return self._forward(*args, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
@@ -690,7 +680,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
def _forward(
|
||||
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
@@ -713,7 +705,18 @@ class DoubleStreamBlock(nn.Module):
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
# make attention mask if not None
|
||||
attn_mask = None
|
||||
if txt_attention_mask is not None:
|
||||
attn_mask = txt_attention_mask # b, seq_len
|
||||
attn_mask = torch.cat(
|
||||
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1
|
||||
) # b, seq_len + img_len
|
||||
|
||||
# broadcast attn_mask to all heads
|
||||
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
||||
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img blocks
|
||||
@@ -725,10 +728,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
def forward(
|
||||
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if not self.cpu_offload_checkpointing:
|
||||
return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False)
|
||||
return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
|
||||
# cpu offload checkpointing
|
||||
|
||||
def create_custom_forward(func):
|
||||
@@ -739,10 +744,10 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe)
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
else:
|
||||
return self._forward(img, txt, vec, pe)
|
||||
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
@@ -992,6 +997,7 @@ class Flux(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -1011,7 +1017,7 @@ class Flux(nn.Module):
|
||||
|
||||
if not self.double_blocks_to_swap:
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
else:
|
||||
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
||||
for block_idx in range(self.double_blocks_to_swap):
|
||||
@@ -1033,7 +1039,7 @@ class Flux(nn.Module):
|
||||
block.to(self.device) # move to cuda
|
||||
# print(f"Moved double block {block_idx} to cuda.")
|
||||
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
if moving:
|
||||
self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||
@@ -1164,6 +1170,7 @@ class FluxUpper(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -1182,7 +1189,7 @@ class FluxUpper(nn.Module):
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
return img, txt, vec, pe
|
||||
|
||||
|
||||
@@ -190,9 +190,10 @@ def sample_image_inference(
|
||||
te_outputs = sample_prompts_te_outputs[prompt]
|
||||
else:
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
|
||||
l_pooled, t5_out, txt_ids = te_outputs
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
@@ -208,9 +209,10 @@ def sample_image_inference(
|
||||
)
|
||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale)
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
|
||||
|
||||
x = x.float()
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
@@ -289,12 +291,22 @@ def denoise(
|
||||
vec: torch.Tensor,
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
@@ -498,7 +510,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--apply_t5_attn_mask",
|
||||
action="store_true",
|
||||
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
|
||||
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
|
||||
@@ -64,22 +64,25 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
l_tokens, t5_tokens = tokens[:2]
|
||||
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
|
||||
|
||||
# clip_l is None when using T5 only
|
||||
if clip_l is not None and l_tokens is not None:
|
||||
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
|
||||
else:
|
||||
l_pooled = None
|
||||
|
||||
# t5xxl is None when using CLIP only
|
||||
if t5xxl is not None and t5_tokens is not None:
|
||||
# t5_out is [b, max length, 4096]
|
||||
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
|
||||
if apply_t5_attn_mask:
|
||||
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
||||
attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
|
||||
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
|
||||
# if zero_pad_t5_output:
|
||||
# t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
||||
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
|
||||
else:
|
||||
t5_out = None
|
||||
txt_ids = None
|
||||
|
||||
return [l_pooled, t5_out, txt_ids]
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
|
||||
|
||||
|
||||
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
@@ -115,6 +118,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return False
|
||||
if "txt_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -129,12 +134,12 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
l_pooled = data["l_pooled"]
|
||||
t5_out = data["t5_out"]
|
||||
txt_ids = data["txt_ids"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
|
||||
if self.apply_t5_attn_mask:
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
|
||||
|
||||
return [l_pooled, t5_out, txt_ids]
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
@@ -145,7 +150,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# attn_mask is not applied when caching to disk: it is applied when loading from disk
|
||||
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
|
||||
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
|
||||
)
|
||||
|
||||
@@ -159,15 +164,15 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
l_pooled = l_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
txt_ids = txt_ids.cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
l_pooled_i = l_pooled[i]
|
||||
t5_out_i = t5_out[i]
|
||||
txt_ids_i = txt_ids[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
t5_attn_mask = tokens_and_masks[2]
|
||||
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy()
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
l_pooled=l_pooled_i,
|
||||
@@ -176,7 +181,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i)
|
||||
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
||||
|
||||
|
||||
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
|
||||
Reference in New Issue
Block a user