mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
refactor SD3 CLIP to transformers etc.
This commit is contained in:
@@ -3,7 +3,7 @@ import glob
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
|
||||
|
||||
from library import sd3_utils, train_util
|
||||
from library import sd3_models
|
||||
@@ -48,45 +48,79 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
|
||||
|
||||
|
||||
class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
|
||||
"""
|
||||
self.apply_lg_attn_mask = apply_lg_attn_mask
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
|
||||
def encode_tokens(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
apply_lg_attn_mask: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
apply_lg_attn_mask: Optional[bool] = False,
|
||||
apply_t5_attn_mask: Optional[bool] = False,
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
returned embeddings are not masked
|
||||
"""
|
||||
clip_l, clip_g, t5xxl = models
|
||||
clip_l: CLIPTextModel
|
||||
clip_g: CLIPTextModelWithProjection
|
||||
t5xxl: T5EncoderModel
|
||||
|
||||
if apply_lg_attn_mask is None:
|
||||
apply_lg_attn_mask = self.apply_lg_attn_mask
|
||||
if apply_t5_attn_mask is None:
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = tokens[:3]
|
||||
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None]
|
||||
|
||||
if len(tokens) > 3:
|
||||
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:]
|
||||
if not apply_lg_attn_mask:
|
||||
l_attn_mask = None
|
||||
g_attn_mask = None
|
||||
else:
|
||||
l_attn_mask = l_attn_mask.to(clip_l.device)
|
||||
g_attn_mask = g_attn_mask.to(clip_g.device)
|
||||
if not apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
else:
|
||||
t5_attn_mask = t5_attn_mask.to(t5xxl.device)
|
||||
else:
|
||||
l_attn_mask = None
|
||||
g_attn_mask = None
|
||||
t5_attn_mask = None
|
||||
|
||||
if l_tokens is None:
|
||||
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
|
||||
lg_out = None
|
||||
lg_pooled = None
|
||||
else:
|
||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||
l_out, l_pooled = clip_l(l_tokens)
|
||||
g_out, g_pooled = clip_g(g_tokens)
|
||||
if apply_lg_attn_mask:
|
||||
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
|
||||
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
with torch.no_grad():
|
||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||
prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
|
||||
l_pooled = prompt_embeds[0]
|
||||
l_out = prompt_embeds.hidden_states[-2]
|
||||
|
||||
prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
|
||||
g_pooled = prompt_embeds[0]
|
||||
g_out = prompt_embeds.hidden_states[-2]
|
||||
|
||||
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
|
||||
if t5xxl is not None and t5_tokens is not None:
|
||||
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
|
||||
if apply_t5_attn_mask:
|
||||
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
||||
with torch.no_grad():
|
||||
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
|
||||
else:
|
||||
t5_out = None
|
||||
|
||||
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
|
||||
return [lg_out, t5_out, lg_pooled]
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer
|
||||
|
||||
def concat_encodings(
|
||||
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
|
||||
@@ -132,39 +166,38 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return False
|
||||
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
|
||||
return False
|
||||
# t5xxl is optional
|
||||
if "apply_lg_attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
|
||||
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray:
|
||||
l_out = lg_out[..., :768]
|
||||
g_out = lg_out[..., 768:] # 1280
|
||||
l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask.
|
||||
g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask.
|
||||
return np.concatenate([l_out, g_out], axis=-1)
|
||||
|
||||
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
|
||||
return t5_out * np.expand_dims(t5_attn_mask, -1)
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
lg_out = data["lg_out"]
|
||||
lg_pooled = data["lg_pooled"]
|
||||
t5_out = data["t5_out"] if "t5_out" in data else None
|
||||
t5_out = data["t5_out"]
|
||||
|
||||
if self.apply_lg_attn_mask:
|
||||
l_attn_mask = data["clip_l_attn_mask"]
|
||||
g_attn_mask = data["clip_g_attn_mask"]
|
||||
lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask)
|
||||
l_attn_mask = data["clip_l_attn_mask"]
|
||||
g_attn_mask = data["clip_g_attn_mask"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
|
||||
if self.apply_t5_attn_mask and t5_out is not None:
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
|
||||
|
||||
return [lg_out, t5_out, lg_pooled]
|
||||
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
@@ -174,7 +207,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens(
|
||||
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
|
||||
)
|
||||
|
||||
@@ -182,38 +215,41 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
lg_out = lg_out.float()
|
||||
if lg_pooled.dtype == torch.bfloat16:
|
||||
lg_pooled = lg_pooled.float()
|
||||
if t5_out is not None and t5_out.dtype == torch.bfloat16:
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
|
||||
lg_out = lg_out.cpu().numpy()
|
||||
lg_pooled = lg_pooled.cpu().numpy()
|
||||
if t5_out is not None:
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
|
||||
l_attn_mask = tokens_and_masks[3].cpu().numpy()
|
||||
g_attn_mask = tokens_and_masks[4].cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
lg_out_i = lg_out[i]
|
||||
t5_out_i = t5_out[i] if t5_out is not None else None
|
||||
t5_out_i = t5_out[i]
|
||||
lg_pooled_i = lg_pooled[i]
|
||||
l_attn_mask_i = l_attn_mask[i]
|
||||
g_attn_mask_i = g_attn_mask[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_lg_attn_mask = self.apply_lg_attn_mask
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6]
|
||||
clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy()
|
||||
clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy()
|
||||
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None
|
||||
kwargs = {}
|
||||
if t5_out is not None:
|
||||
kwargs["t5_out"] = t5_out_i
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
lg_out=lg_out_i,
|
||||
lg_pooled=lg_pooled_i,
|
||||
clip_l_attn_mask=clip_l_attn_mask_i,
|
||||
clip_g_attn_mask=clip_g_attn_mask_i,
|
||||
t5_out=t5_out_i,
|
||||
clip_l_attn_mask=l_attn_mask_i,
|
||||
clip_g_attn_mask=g_attn_mask_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
**kwargs,
|
||||
apply_lg_attn_mask=apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
|
||||
|
||||
|
||||
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
@@ -246,41 +282,3 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test code for Sd3TokenizeStrategy
|
||||
# tokenizer = sd3_models.SD3Tokenizer()
|
||||
strategy = Sd3TokenizeStrategy(256)
|
||||
text = "hello world"
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
|
||||
# print(l_tokens.shape)
|
||||
print(l_tokens)
|
||||
print(g_tokens)
|
||||
print(t5_tokens)
|
||||
|
||||
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
|
||||
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||
t5_tokens_2 = strategy.t5xxl(
|
||||
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
print(l_tokens_2)
|
||||
print(g_tokens_2)
|
||||
print(t5_tokens_2)
|
||||
|
||||
# compare
|
||||
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
|
||||
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
|
||||
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
|
||||
|
||||
text = ",".join(["hello world! this is long text"] * 50)
|
||||
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
|
||||
print(l_tokens)
|
||||
print(g_tokens)
|
||||
print(t5_tokens)
|
||||
|
||||
print(f"model max length l: {strategy.clip_l.model_max_length}")
|
||||
print(f"model max length g: {strategy.clip_g.model_max_length}")
|
||||
print(f"model max length t5: {strategy.t5xxl.model_max_length}")
|
||||
|
||||
Reference in New Issue
Block a user