support attn mask for l+g/t5

This commit is contained in:
Kohya S
2024-08-05 20:51:34 +09:00
parent 231df197dd
commit da4d0fe016
4 changed files with 107 additions and 24 deletions

View File

@@ -37,11 +37,14 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
l_attn_mask = l_tokens["attention_mask"]
g_attn_mask = g_tokens["attention_mask"]
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
g_tokens = g_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, g_tokens, t5_tokens]
return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
class Sd3TextEncodingStrategy(TextEncodingStrategy):
@@ -49,11 +52,20 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
pass
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
"""
clip_l, clip_g, t5xxl = models
l_tokens, g_tokens, t5_tokens = tokens
l_tokens, g_tokens, t5_tokens = tokens[:3]
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None]
if l_tokens is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
@@ -61,10 +73,15 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
l_out, l_pooled = clip_l(l_tokens)
g_out, g_pooled = clip_g(g_tokens)
if apply_lg_attn_mask:
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is not None and t5_tokens is not None:
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
if apply_t5_attn_mask:
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
else:
t5_out = None
@@ -84,50 +101,81 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, abs_path: str):
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(self.get_outputs_npz_path(abs_path))
if "clip_l" not in npz or "clip_g" not in npz:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "clip_l_pool" not in npz or "clip_g_pool" not in npz:
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
# t5xxl is optional
except Exception as e:
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray:
l_out = lg_out[..., :768]
g_out = lg_out[..., 768:] # 1280
l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask.
g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask.
return np.concatenate([l_out, g_out], axis=-1)
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
return t5_out * np.expand_dims(t5_attn_mask, -1)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"] if "t5_out" in data else None
if self.apply_lg_attn_mask:
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask)
if self.apply_t5_attn_mask and t5_out is not None:
t5_attn_mask = data["t5_attn_mask"]
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
return [lg_out, t5_out, lg_pooled]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions)
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens]
lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
)
if lg_out.dtype == torch.bfloat16:
@@ -148,10 +196,22 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
lg_pooled_i = lg_pooled[i]
if self.cache_to_disk:
clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6]
clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy()
clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy()
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None
kwargs = {}
if t5_out is not None:
kwargs["t5_out"] = t5_out_i
np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs)
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
clip_l_attn_mask=clip_l_attn_mask_i,
clip_g_attn_mask=clip_g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
**kwargs,
)
else:
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)

View File

@@ -646,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset):
# caching
self.caching_mode = None # None, 'latents', 'text'
self.tokenize_strategy = None
self.text_encoder_output_caching_strategy = None
self.latents_caching_strategy = None
@@ -1486,6 +1486,7 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
image_info.text_encoder_outputs_npz
)
text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs]
else:
tokenization_required = True
text_encoder_outputs_list.append(text_encoder_outputs)