workaround for long caption ref #1382

This commit is contained in:
Kohya S
2024-06-24 23:13:14 +09:00
parent 0fe4eafac9
commit 4802e4aaec

View File

@@ -56,7 +56,7 @@ class SDTokenizer:
self.inv_vocab = {v: k for k, v in vocab.items()} self.inv_vocab = {v: k for k, v in vocab.items()}
self.max_word_length = 8 self.max_word_length = 8
def tokenize_with_weights(self, text: str): def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
""" """
@@ -79,6 +79,14 @@ class SDTokenizer:
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length: if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
# truncate to max_length
# print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
if truncate_to_max_length and len(batch) > self.max_length:
batch = batch[: self.max_length]
if truncate_length is not None and len(batch) > truncate_length:
batch = batch[:truncate_length]
return [batch] return [batch]
@@ -112,10 +120,15 @@ class SD3Tokenizer:
self.model_max_length = self.clip_l.max_length # 77 self.model_max_length = self.clip_l.max_length # 77
def tokenize_with_weights(self, text: str): def tokenize_with_weights(self, text: str):
# temporary truncate to max_length even for t5xxl
return ( return (
self.clip_l.tokenize_with_weights(text), self.clip_l.tokenize_with_weights(text),
self.clip_g.tokenize_with_weights(text), self.clip_g.tokenize_with_weights(text),
self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None, (
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
if self.t5xxl is not None
else None
),
) )