mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
workaround for long caption ref #1382
@@ -56,7 +56,7 @@ class SDTokenizer:
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.max_word_length = 8
 
-    def tokenize_with_weights(self, text: str):
+    def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
         """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
         The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
         """
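The signature change above is backward compatible: with the default truncate_to_max_length=True and truncate_length=None, the method still caps the batch at the tokenizer's own max_length, while callers can now opt out of that cap or request an explicit one. A hedged illustration of the two call styles (the `tokenizer` instance and the length 154 are assumptions for the example, not values from this commit):

    # old behaviour, unchanged by the new defaults
    tokens = tokenizer.tokenize_with_weights(caption)
    # keep the full caption, or cap it at an explicit length instead
    tokens = tokenizer.tokenize_with_weights(caption, truncate_to_max_length=False, truncate_length=154)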
@@ -79,6 +79,14 @@ class SDTokenizer:
             batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
         if self.min_length is not None and len(batch) < self.min_length:
             batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
+
+        # truncate to max_length
+        # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
+        if truncate_to_max_length and len(batch) > self.max_length:
+            batch = batch[: self.max_length]
+        if truncate_length is not None and len(batch) > truncate_length:
+            batch = batch[:truncate_length]
+
         return [batch]
 
 
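To make the interaction of the two new parameters concrete, here is a minimal, self-contained sketch of the truncation policy added in this hunk. The helper `_truncate` is hypothetical and exists only for illustration; it mirrors the two new branches outside the class:

    def _truncate(batch, max_length, truncate_to_max_length=True, truncate_length=None):
        # default behaviour: cap the batch at the tokenizer's own max_length
        if truncate_to_max_length and len(batch) > max_length:
            batch = batch[:max_length]
        # optional explicit cap, applied regardless of the flag above
        if truncate_length is not None and len(batch) > truncate_length:
            batch = batch[:truncate_length]
        return batch

    tokens = [(i, 1.0) for i in range(300)]  # stand-in for a long (token, weight) batch
    assert len(_truncate(tokens, 77)) == 77                                                   # old default
    assert len(_truncate(tokens, 77, truncate_to_max_length=False)) == 300                    # no cap
    assert len(_truncate(tokens, 77, truncate_to_max_length=False, truncate_length=154)) == 154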
@@ -112,10 +120,15 @@ class SD3Tokenizer:
         self.model_max_length = self.clip_l.max_length  # 77
 
     def tokenize_with_weights(self, text: str):
+        # temporary truncate to max_length even for t5xxl
         return (
             self.clip_l.tokenize_with_weights(text),
             self.clip_g.tokenize_with_weights(text),
-            self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None,
+            (
+                self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
+                if self.t5xxl is not None
+                else None
+            ),
         )
 
 
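Taken together, SD3Tokenizer keeps the usual 77-token behaviour for clip_l and clip_g, while t5xxl skips its own max_length cap and is instead truncated to model_max_length as the temporary workaround. A hedged usage sketch; the import path and the no-argument constructor are assumptions and may differ from the actual sd-scripts module:

    from library.sd3_models import SD3Tokenizer  # assumed import path

    tokenizer = SD3Tokenizer()  # assumed default construction
    long_caption = ", ".join(["a very long and detailed caption fragment"] * 40)

    l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(long_caption)
    # clip_l / clip_g batches are padded or truncated to their usual length;
    # the t5xxl batch is truncated to model_max_length by the added keyword arguments.
    print(len(l_tokens[0]), len(g_tokens[0]), None if t5_tokens is None else len(t5_tokens[0]))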