diff --git a/library/sd3_models.py b/library/sd3_models.py
index a4fe400e..c19aec6a 100644
--- a/library/sd3_models.py
+++ b/library/sd3_models.py
@@ -56,7 +56,7 @@ class SDTokenizer:
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.max_word_length = 8
 
-    def tokenize_with_weights(self, text: str):
+    def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
         """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
         """
@@ -79,6 +79,14 @@ class SDTokenizer:
             batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
         if self.min_length is not None and len(batch) < self.min_length:
             batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
+
+        # truncate to max_length
+        # print(f"batch: {batch}, truncate_to_max_length: {truncate_to_max_length}, truncate_length: {truncate_length}, len(batch): {len(batch)}, max_length: {self.max_length}")
+        if truncate_to_max_length and len(batch) > self.max_length:
+            batch = batch[: self.max_length]
+        if truncate_length is not None and len(batch) > truncate_length:
+            batch = batch[:truncate_length]
+
         return [batch]
@@ -112,10 +120,15 @@ class SD3Tokenizer:
         self.model_max_length = self.clip_l.max_length  # 77
 
     def tokenize_with_weights(self, text: str):
+        # temporarily truncate to max_length even for t5xxl
         return (
             self.clip_l.tokenize_with_weights(text),
             self.clip_g.tokenize_with_weights(text),
-            self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None,
+            (
+                self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
+                if self.t5xxl is not None
+                else None
+            ),
         )
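For context, below is a minimal standalone sketch (not part of the patch) of the pad-then-truncate semantics the new parameters introduce. `apply_truncation` is a hypothetical helper: it takes the relevant `SDTokenizer` attributes (`max_length`, `min_length`, pad behavior) as plain arguments and operates on lists of `(token, weight)` tuples, so it runs without the real tokenizer classes.

def apply_truncation(batch, max_length, min_length=None, pad_token=0,
                     pad_to_max_length=True, truncate_to_max_length=True,
                     truncate_length=None):
    # Sketch of the tail of SDTokenizer.tokenize_with_weights: pad first, then truncate.
    batch = list(batch)
    # pre-existing behavior: pad with (pad_token, 1.0) up to max_length / min_length
    if pad_to_max_length and len(batch) < max_length:
        batch.extend([(pad_token, 1.0)] * (max_length - len(batch)))
    if min_length is not None and len(batch) < min_length:
        batch.extend([(pad_token, 1.0)] * (min_length - len(batch)))
    # new in this patch: optional cap at max_length, then at an explicit truncate_length
    if truncate_to_max_length and len(batch) > max_length:
        batch = batch[:max_length]
    if truncate_length is not None and len(batch) > truncate_length:
        batch = batch[:truncate_length]
    return [batch]

tokens = [(i, 1.0) for i in range(100)]  # a 100-token prompt

# CLIP-style call (defaults): capped at the tokenizer's own max_length
assert len(apply_truncation(tokens, max_length=77)[0]) == 77

# T5-XXL-style call, as SD3Tokenizer now makes it: skip the cap at the
# tokenizer's own max_length (assumed here to be effectively unbounded)
# and instead clamp to the 77-token CLIP length via truncate_length
assert len(apply_truncation(tokens, max_length=99999, min_length=77,
                            pad_to_max_length=False, truncate_to_max_length=False,
                            truncate_length=77)[0]) == 77

The design reads as deliberately reversible: truncate_to_max_length=False leaves the T5 tokenizer's own limit untouched, and the "temporary" clamp of t5xxl output to self.model_max_length (77, the CLIP length, per the context line in the last hunk) lives entirely in the explicit truncate_length argument, so lifting it later only requires dropping that argument at the call site.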