refactor SD3 CLIP to transformers etc.

This commit is contained in:
Kohya S
2024-10-24 19:49:28 +09:00
parent 138dac4aea
commit 623017f716
13 changed files with 1201 additions and 2150 deletions

View File

@@ -3,7 +3,7 @@ import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_utils, train_util
from library import sd3_models
@@ -48,45 +48,79 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
"""
clip_l, clip_g, t5xxl = models
clip_l: CLIPTextModel
clip_g: CLIPTextModelWithProjection
t5xxl: T5EncoderModel
if apply_lg_attn_mask is None:
apply_lg_attn_mask = self.apply_lg_attn_mask
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
l_tokens, g_tokens, t5_tokens = tokens[:3]
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None]
if len(tokens) > 3:
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:]
if not apply_lg_attn_mask:
l_attn_mask = None
g_attn_mask = None
else:
l_attn_mask = l_attn_mask.to(clip_l.device)
g_attn_mask = g_attn_mask.to(clip_g.device)
if not apply_t5_attn_mask:
t5_attn_mask = None
else:
t5_attn_mask = t5_attn_mask.to(t5xxl.device)
else:
l_attn_mask = None
g_attn_mask = None
t5_attn_mask = None
if l_tokens is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
lg_pooled = None
else:
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
l_out, l_pooled = clip_l(l_tokens)
g_out, g_pooled = clip_g(g_tokens)
if apply_lg_attn_mask:
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
lg_out = torch.cat([l_out, g_out], dim=-1)
with torch.no_grad():
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
l_pooled = prompt_embeds[0]
l_out = prompt_embeds.hidden_states[-2]
prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
g_pooled = prompt_embeds[0]
g_out = prompt_embeds.hidden_states[-2]
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is not None and t5_tokens is not None:
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
if apply_t5_attn_mask:
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
with torch.no_grad():
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
else:
t5_out = None
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
return [lg_out, t5_out, lg_pooled]
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer
def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
@@ -132,39 +166,38 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
# t5xxl is optional
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray:
l_out = lg_out[..., :768]
g_out = lg_out[..., 768:] # 1280
l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask.
g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask.
return np.concatenate([l_out, g_out], axis=-1)
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
return t5_out * np.expand_dims(t5_attn_mask, -1)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"] if "t5_out" in data else None
t5_out = data["t5_out"]
if self.apply_lg_attn_mask:
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask)
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
t5_attn_mask = data["t5_attn_mask"]
if self.apply_t5_attn_mask and t5_out is not None:
t5_attn_mask = data["t5_attn_mask"]
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
return [lg_out, t5_out, lg_pooled]
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
@@ -174,7 +207,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens(
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
)
@@ -182,38 +215,41 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out is not None and t5_out.dtype == torch.bfloat16:
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
if t5_out is not None:
t5_out = t5_out.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i] if t5_out is not None else None
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6]
clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy()
clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy()
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None
kwargs = {}
if t5_out is not None:
kwargs["t5_out"] = t5_out_i
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
clip_l_attn_mask=clip_l_attn_mask_i,
clip_g_attn_mask=clip_g_attn_mask_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
**kwargs,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
@@ -246,41 +282,3 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# test code for Sd3TokenizeStrategy
# tokenizer = sd3_models.SD3Tokenizer()
strategy = Sd3TokenizeStrategy(256)
text = "hello world"
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
# print(l_tokens.shape)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens_2 = strategy.t5xxl(
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(l_tokens_2)
print(g_tokens_2)
print(t5_tokens_2)
# compare
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
text = ",".join(["hello world! this is long text"] * 50)
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
print(f"model max length l: {strategy.clip_l.model_max_length}")
print(f"model max length g: {strategy.clip_g.model_max_length}")
print(f"model max length t5: {strategy.t5xxl.model_max_length}")