Refactor caching mechanism for latents and text encoder outputs, etc.
@@ -38,7 +38,7 @@ class SDTokenizer:
        サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。
        Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings.
        """
        self.tokenizer = tokenizer
        self.tokenizer: CLIPTokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length
        empty = self.tokenizer("")["input_ids"]
@@ -56,6 +56,19 @@ class SDTokenizer:
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.max_word_length = 8

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """
        Tokenize the text without weights.
        """
        if type(text) == str:
            text = [text]
        batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
        # return tokens["input_ids"]

        pad_token = self.end_token if self.pad_with_end else 0
        for tokens in batch_tokens["input_ids"]:
            assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}"

    def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
        """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
        The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
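The added tokenize() method is essentially a thin wrapper around the Hugging Face tokenizer call it contains. As a point of reference, a minimal standalone sketch of that call (not code from the repo; the model id and the 77-token length are taken from the SD3Tokenizer hunk further down) looks like this:

from transformers import CLIPTokenizer

# Pad/truncate a batch of prompts to CLIP's 77-token context, as the new
# SDTokenizer.tokenize does with padding="max_length" and self.max_length.
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
batch_tokens = clip_tokenizer(["a photo of a cat"], padding="max_length", truncation=True, max_length=77, return_tensors="pt")
print(batch_tokens["input_ids"].shape)  # torch.Size([1, 77]); each row starts with the BOS token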
@@ -75,13 +88,14 @@ class SDTokenizer:
        for word in to_tokenize:
            batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
        batch.append((self.end_token, 1.0))
        print(len(batch), self.max_length, self.min_length)
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))

        # truncate to max_length
        # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
        print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
        if truncate_to_max_length and len(batch) > self.max_length:
            batch = batch[: self.max_length]
        if truncate_length is not None and len(batch) > truncate_length:
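The padding and truncation rules above are easier to follow in isolation. Below is a simplified, self-contained sketch (pad_and_truncate is a hypothetical helper, not a function from the repo) of how a list of (token, weight) pairs is padded up to max_length/min_length and then capped by truncate_to_max_length and truncate_length:

def pad_and_truncate(batch, pad_token, max_length=77, min_length=77,
                     pad_to_max_length=True, truncate_to_max_length=True, truncate_length=None):
    batch = list(batch)
    if pad_to_max_length:
        # pad up to max_length (no-op if the batch is already longer)
        batch.extend([(pad_token, 1.0)] * (max_length - len(batch)))
    if min_length is not None and len(batch) < min_length:
        batch.extend([(pad_token, 1.0)] * (min_length - len(batch)))
    if truncate_to_max_length and len(batch) > max_length:
        batch = batch[:max_length]
    if truncate_length is not None and len(batch) > truncate_length:
        batch = batch[:truncate_length]
    return batch

print(len(pad_and_truncate([(320, 1.0)] * 10, pad_token=0)))  # 77: padded up to max_length
print(len(pad_and_truncate([(320, 1.0)] * 300, pad_token=0,
                           truncate_to_max_length=False, truncate_length=256)))  # 256: the T5-style path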
@@ -110,27 +124,38 @@ class SDXLClipGTokenizer(SDTokenizer):


class SD3Tokenizer:
    def __init__(self, t5xxl=True):
    def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256):
        if t5xxl_max_length is None:
            t5xxl_max_length = 256

        # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
        # self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        # self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
        self.t5xxl = T5XXLTokenizer() if t5xxl else None
        # t5xxl has 99999999 max length, clip has 77
        self.model_max_length = self.clip_l.max_length # 77
        self.t5xxl_max_length = t5xxl_max_length

    def tokenize_with_weights(self, text: str):
        # temporary truncate to max_length even for t5xxl
        return (
            self.clip_l.tokenize_with_weights(text),
            self.clip_g.tokenize_with_weights(text),
            (
                self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
                self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length)
                if self.t5xxl is not None
                else None
            ),
        )

    def tokenize(self, text: str):
        return (
            self.clip_l.tokenize(text),
            self.clip_g.tokenize(text),
            (self.t5xxl.tokenize(text) if self.t5xxl is not None else None),
        )


# endregion

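For orientation, a hedged usage sketch of the updated SD3Tokenizer (the import path is hypothetical, and loading the tokenizers requires downloading them from the Hugging Face hub):

from library.sd3_models import SD3Tokenizer  # hypothetical module path

tokenizer = SD3Tokenizer(t5xxl=True, t5xxl_max_length=256)
l_pairs, g_pairs, t5_pairs = tokenizer.tokenize_with_weights("a photo of a cat")
# l_pairs / g_pairs come from the CLIP tokenizers and are truncated to the 77-token limit;
# t5_pairs is no longer cut at model_max_length but at t5xxl_max_length (256 by default).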
@@ -1474,7 +1499,10 @@ class ClipTokenWeightEncoder:
                tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0]
                list_of_tokens.append(tokens)
        else:
            list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
            if isinstance(list_of_token_weight_pairs[0], torch.Tensor):
                list_of_tokens = [list(list_of_token_weight_pairs[0])]
            else:
                list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]

        out, pooled = self(list_of_tokens)
        if has_batch:
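The new isinstance branch covers the case where a single prompt arrives as an already-tokenized tensor of ids rather than as (token, weight) pairs. A standalone sketch of that normalization (to_token_lists is a hypothetical name, and this simplification ignores the batched path):

import torch

def to_token_lists(list_of_token_weight_pairs):
    first = list_of_token_weight_pairs[0]
    if isinstance(first, torch.Tensor):
        return [list(first)]  # tensor of ids -> a single list of ids
    return [[a[0] for a in first]]  # (token, weight) pairs -> drop the weights

print(to_token_lists([torch.tensor([49406, 320, 49407])]))
print(to_token_lists([[(49406, 1.0), (320, 1.0), (49407, 1.0)]]))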
@@ -1614,9 +1642,9 @@ class T5XXLModel(SDClipModel):
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################


"""
class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
    ""Wraps the T5 Tokenizer from HF into the SDTokenizer interface""

    def __init__(self):
        super().__init__(
@@ -1627,6 +1655,7 @@ class T5XXLTokenizer(SDTokenizer):
            max_length=99999999,
            min_length=77,
        )
"""


class T5LayerNorm(torch.nn.Module):