mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
sample generation in SDXL ControlNet training
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# base class for platform strategies. this file defines the interface for strategies
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -22,6 +23,24 @@ logger = logging.getLogger(__name__)
|
||||
class TokenizeStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
_re_attention = re.compile(
|
||||
r"""\\\(|
|
||||
\\\)|
|
||||
\\\[|
|
||||
\\]|
|
||||
\\\\|
|
||||
\\|
|
||||
\(|
|
||||
\[|
|
||||
:([+-]?[.\d]+)\)|
|
||||
\)|
|
||||
]|
|
||||
[^\\()\[\]:]+|
|
||||
:
|
||||
""",
|
||||
re.X,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
@@ -54,7 +73,151 @@ class TokenizeStrategy:
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
    """
    Tokenize a prompt or a list of prompts.

    This is the interface definition only; concrete strategies must override it.

    :param text: a single prompt or a list of prompts
    :return: list of token tensors (presumably one per text model, as in
        `encode_tokens_with_weights` - confirm against the concrete strategies)
    :raises NotImplementedError: always, in this base class
    """
    raise NotImplementedError
|
||||
|
||||
def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
|
||||
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """
    Tokenize a prompt or a list of prompts and also return per-token weights.

    This is the interface definition only; concrete strategies must override it.

    :param text: a single prompt or a list of prompts
    :return: tuple of (list of token tensors, list of weight tensors)
    :raises NotImplementedError: always, in this base class
    """
    raise NotImplementedError
|
||||
|
||||
def _get_weighted_input_ids(
    self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Tokenize a prompt written with attention-weight syntax and return the
    token ids together with one weight per token.

    :param tokenizer: tokenizer used for each weighted fragment of the prompt
    :param text: prompt, optionally containing ``(abc)``, ``(abc:1.2)`` or ``[abc]`` markup
    :param max_length: total sequence length, including the starting and ending
        tokens. Defaults to ``tokenizer.model_max_length``.
    :return: ``(input_ids, weights)``, each of shape ``(1, max_length)``. The
        starting token, ending token and padding all get weight 1.0.
    """

    def parse_prompt_attention(text):
        r"""
        Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
        Accepted tokens are:
          (abc) - increases attention to abc by a multiplier of 1.1
          (abc:3.12) - increases attention to abc by a multiplier of 3.12
          [abc] - decreases attention to abc by a multiplier of 1.1
          \( - literal character '('
          \[ - literal character '['
          \) - literal character ')'
          \] - literal character ']'
          \\ - literal character '\'
          anything else - just text
        >>> parse_prompt_attention('normal text')
        [['normal text', 1.0]]
        >>> parse_prompt_attention('an (important) word')
        [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
        >>> parse_prompt_attention('(unbalanced')
        [['unbalanced', 1.1]]
        >>> parse_prompt_attention(r'\(literal\]')
        [['(literal]', 1.0]]
        >>> parse_prompt_attention('(unnecessary)(parens)')
        [['unnecessaryparens', 1.1]]
        >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
        [['a ', 1.0],
         ['house', 1.5730000000000004],
         [' ', 1.1],
         ['on', 1.0],
         [' a ', 1.1],
         ['hill', 0.55],
         [', sun, ', 1.1],
         ['sky', 1.4641000000000006],
         ['.', 1.1]]
        """

        res = []
        # Each stack holds the index into `res` at which a still-open bracket started.
        round_brackets = []
        square_brackets = []

        round_bracket_multiplier = 1.1
        square_bracket_multiplier = 1 / 1.1

        def multiply_range(start_position, multiplier):
            # Scale every fragment emitted since the bracket was opened.
            for p in range(start_position, len(res)):
                res[p][1] *= multiplier

        for m in TokenizeStrategy._re_attention.finditer(text):
            # `piece` rather than `text`, so the function argument is not shadowed.
            piece = m.group(0)
            weight = m.group(1)

            if piece.startswith("\\"):
                # Escaped character: drop the backslash, keep the rest literally.
                res.append([piece[1:], 1.0])
            elif piece == "(":
                round_brackets.append(len(res))
            elif piece == "[":
                square_brackets.append(len(res))
            elif weight is not None and round_brackets:
                # NOTE: the pattern accepts any run of digits and dots, so a weight such as
                # "1.2.3" reaches float() and raises ValueError.
                multiply_range(round_brackets.pop(), float(weight))
            elif piece == ")" and round_brackets:
                multiply_range(round_brackets.pop(), round_bracket_multiplier)
            elif piece == "]" and square_brackets:
                multiply_range(square_brackets.pop(), square_bracket_multiplier)
            else:
                # Plain text, or a closing token with no matching opener.
                res.append([piece, 1.0])

        # Brackets left open are treated as if they were closed at the end of the prompt.
        for pos in round_brackets:
            multiply_range(pos, round_bracket_multiplier)

        for pos in square_brackets:
            multiply_range(pos, square_bracket_multiplier)

        if not res:
            res = [["", 1.0]]

        # merge runs of identical weights
        i = 0
        while i + 1 < len(res):
            if res[i][1] == res[i + 1][1]:
                res[i][0] += res[i + 1][0]
                res.pop(i + 1)
            else:
                i += 1

        return res

    def get_prompts_with_weights(text: str, max_length: int):
        r"""
        Tokenize a single prompt and return its tokens with the weight of each token.
        max_length does not include the starting and ending token.

        No padding, starting or ending token is included in the result.
        """
        tokens = []
        weights = []
        for word, weight in parse_prompt_attention(text):
            # tokenize and discard the starting and the ending token
            token = tokenizer(word).input_ids[1:-1]
            tokens += token
            # copy the weight by length of token
            weights += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(tokens) > max_length:
                break

        # truncate
        if len(tokens) > max_length:
            logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
            tokens = tokens[:max_length]
            weights = weights[:max_length]
        return tokens, weights

    def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
        r"""
        Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
        """
        tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
        # One leading 1.0 for the starting token; the trailing run covers the ending token and the padding.
        weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
        return tokens, weights

    if max_length is None:
        max_length = tokenizer.model_max_length

    # Reserve two positions for the starting and ending tokens.
    tokens, weights = get_prompts_with_weights(text, max_length - 2)
    tokens, weights = pad_tokens_and_weights(
        tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
    )
    return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
|
||||
|
||||
def _get_input_ids(
|
||||
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
for SD1.5/2.0/SDXL
|
||||
TODO support batch input
|
||||
@@ -62,7 +225,10 @@ class TokenizeStrategy:
|
||||
if max_length is None:
|
||||
max_length = tokenizer.model_max_length - 2
|
||||
|
||||
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
|
||||
if weighted:
|
||||
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
|
||||
else:
|
||||
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
|
||||
|
||||
if max_length > tokenizer.model_max_length:
|
||||
input_ids = input_ids.squeeze(0)
|
||||
@@ -101,6 +267,17 @@ class TokenizeStrategy:
|
||||
iids_list.append(ids_chunk)
|
||||
|
||||
input_ids = torch.stack(iids_list) # 3,77
|
||||
|
||||
if weighted:
|
||||
weights = weights.squeeze(0)
|
||||
new_weights = torch.ones(input_ids.shape)
|
||||
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
||||
b = i // (tokenizer.model_max_length - 2)
|
||||
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
|
||||
weights = new_weights
|
||||
|
||||
if weighted:
|
||||
return input_ids, weights
|
||||
return input_ids
|
||||
|
||||
|
||||
@@ -126,6 +303,17 @@ class TextEncodingStrategy:
|
||||
:return: list of output embeddings for each architecture
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_tokens_with_weights(
    self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
    """
    Encode tokens into embeddings and outputs, taking per-token weights into account.

    This is the interface definition only; concrete strategies must override it.

    :param tokenize_strategy: tokenize strategy paired with this encoding strategy
    :param models: list of text models
    :param tokens: list of token tensors for each TextModel
    :param weights: list of weight tensors for each TextModel
    :return: list of output embeddings for each architecture
    :raises NotImplementedError: always, in this base class
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
class TextEncoderOutputsCachingStrategy:
|
||||
|
||||
Reference in New Issue
Block a user