Refactor caching mechanism for latents and text encoder outputs, etc.

Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions


@@ -38,7 +38,7 @@ class SDTokenizer:
Various settings are configured in the subclasses; based on those settings, this class performs weighted tokenization.
"""
self.tokenizer = tokenizer
self.tokenizer: CLIPTokenizer = tokenizer
self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer("")["input_ids"]
@@ -56,6 +56,19 @@ class SDTokenizer:
self.inv_vocab = {v: k for k, v in vocab.items()}
self.max_word_length = 8
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
"""
Tokenize the text without weights.
"""
if isinstance(text, str):
text = [text]
batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
pad_token = self.end_token if self.pad_with_end else 0
for tokens in batch_tokens["input_ids"]:
# each sequence must begin with the start (BOS) token
assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}"
return batch_tokens["input_ids"]
def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
The details aren't relevant for a reference impl, and the weights themselves have a weak effect on SD3."""
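
The new tokenize() above delegates padding and truncation to the Hugging Face tokenizer and then checks the start-token invariant. A minimal sketch of the same call outside the class, assuming the openai/clip-vit-large-patch14 checkpoint (names below are illustrative, not from the diff):

from transformers import CLIPTokenizer

tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = tok(["a photo of a cat"], padding="max_length", truncation=True, max_length=77, return_tensors="pt")["input_ids"]
print(ids.shape)  # torch.Size([1, 77]): padded/truncated to CLIP's max length
assert ids[0, 0] == tok.bos_token_id  # the same per-sequence check tokenize() asserts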
@@ -75,13 +88,14 @@ class SDTokenizer:
for word in to_tokenize:
batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
batch.append((self.end_token, 1.0))
print(len(batch), self.max_length, self.min_length)
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
# truncate to max_length
# print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
if truncate_to_max_length and len(batch) > self.max_length:
batch = batch[: self.max_length]
if truncate_length is not None and len(batch) > truncate_length:
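
Taken together, the branches above implement a pad-then-truncate policy: pad with pad_token up to max_length (or at least min_length), then optionally cut back to max_length and/or an explicit truncate_length. A standalone restatement as an illustrative sketch; the function name is hypothetical:

def pad_and_truncate(batch, pad_token, max_length, min_length=None,
                     pad_to_max_length=True, truncate_to_max_length=True, truncate_length=None):
    # pad with (pad_token, 1.0) pairs, mirroring tokenize_with_weights()
    if pad_to_max_length:
        batch = batch + [(pad_token, 1.0)] * (max_length - len(batch))
    if min_length is not None and len(batch) < min_length:
        batch = batch + [(pad_token, 1.0)] * (min_length - len(batch))
    if truncate_to_max_length and len(batch) > max_length:
        batch = batch[:max_length]
    if truncate_length is not None and len(batch) > truncate_length:
        batch = batch[:truncate_length]
    return batch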
@@ -110,27 +124,38 @@ class SDXLClipGTokenizer(SDTokenizer):
class SD3Tokenizer:
def __init__(self, t5xxl=True):
def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256):
if t5xxl_max_length is None:
t5xxl_max_length = 256
# TODO cache tokenizer settings locally or hold them in the repo like ComfyUI
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
# self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
self.t5xxl = T5XXLTokenizer() if t5xxl else None
# t5xxl has 99999999 max length, clip has 77
self.model_max_length = self.clip_l.max_length # 77
self.t5xxl_max_length = t5xxl_max_length
def tokenize_with_weights(self, text: str):
# temporarily truncate to t5xxl_max_length even for t5xxl
return (
self.clip_l.tokenize_with_weights(text),
self.clip_g.tokenize_with_weights(text),
(
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length)
if self.t5xxl is not None
else None
),
)
def tokenize(self, text: str):
return (
self.clip_l.tokenize(text),
self.clip_g.tokenize(text),
(self.t5xxl.tokenize(text) if self.t5xxl is not None else None),
)
# endregion
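
With the new t5xxl_max_length argument, CLIP-L and CLIP-G keep their 77-token window while the T5-XXL stream is now cut at 256 tokens by default rather than at the CLIP length. A hedged usage sketch of the class above (assumes the HF tokenizer checkpoints are downloadable):

tokenizer = SD3Tokenizer(t5xxl=True, t5xxl_max_length=256)
l_pairs, g_pairs, t5_pairs = tokenizer.tokenize_with_weights("a photo of a cat")
# l_pairs / g_pairs: (token, weight) pairs, typically padded to 77; t5_pairs: at most 256 pairs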
@@ -1474,7 +1499,10 @@ class ClipTokenWeightEncoder:
tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0]
list_of_tokens.append(tokens)
else:
list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
if isinstance(list_of_token_weight_pairs[0], torch.Tensor):
list_of_tokens = [list(list_of_token_weight_pairs[0])]
else:
list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
out, pooled = self(list_of_tokens)
if has_batch:
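
The new isinstance branch lets ClipTokenWeightEncoder accept either the plain token tensors produced by tokenize() or the (token, weight) pairs produced by tokenize_with_weights(). The same normalization in isolation, with hypothetical names:

import torch

def to_token_list(entry):
    # entry: either a 1-D tensor of token ids, or a list of (token, weight) pairs
    if isinstance(entry, torch.Tensor):
        return list(entry)  # tensor -> list of scalar id tensors
    return [pair[0] for pair in entry]  # keep tokens, drop weights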
@@ -1614,9 +1642,9 @@ class T5XXLModel(SDClipModel):
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
"""
class T5XXLTokenizer(SDTokenizer):
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
""Wraps the T5 Tokenizer from HF into the SDTokenizer interface""
def __init__(self):
super().__init__(
@@ -1627,6 +1655,7 @@ class T5XXLTokenizer(SDTokenizer):
max_length=99999999,
min_length=77,
)
"""
class T5LayerNorm(torch.nn.Module):
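
The T5LayerNorm body falls outside this hunk. For orientation only: T5-style layer norm in the upstream implementation this file says it pulls from is an RMS norm, i.e. no mean subtraction and no bias, with statistics computed in float32. A minimal sketch under that assumption, not the diff's exact code:

import torch

class T5LayerNormSketch(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # scale by the root-mean-square; no mean subtraction, no bias
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x.to(torch.float32) * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(self.weight.dtype)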