mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
sample generation in SDXL ControlNet training
This commit is contained in:
@@ -37,6 +37,22 @@ class SdxlTokenizeStrategy(TokenizeStrategy):
|
||||
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
|
||||
)
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
tokens1_list, tokens2_list = [], []
|
||||
weights1_list, weights2_list = [], []
|
||||
for t in text:
|
||||
tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length)
|
||||
tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length)
|
||||
tokens1_list.append(tokens1)
|
||||
tokens2_list.append(tokens2)
|
||||
weights1_list.append(weights1)
|
||||
weights2_list.append(weights2)
|
||||
return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), (
|
||||
torch.stack(weights1_list, dim=0),
|
||||
torch.stack(weights2_list, dim=0),
|
||||
)
|
||||
|
||||
|
||||
class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
@@ -98,7 +114,10 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
):
|
||||
# input_ids: b,n,77 -> b*n, 77
|
||||
b_size = input_ids1.size()[0]
|
||||
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
|
||||
if input_ids1.size()[1] == 1:
|
||||
max_token_length = None
|
||||
else:
|
||||
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
|
||||
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
|
||||
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
||||
input_ids1 = input_ids1.to(text_encoder1.device)
|
||||
@@ -172,6 +191,24 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
)
|
||||
return [hidden_states1, hidden_states2, pool2]
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens)
|
||||
|
||||
# apply weights
|
||||
if weights[0].shape[1] == 1: # no max_token_length
|
||||
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
||||
hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2)
|
||||
hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2)
|
||||
else:
|
||||
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
||||
for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]):
|
||||
for i in range(weight.shape[1]):
|
||||
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1]
|
||||
|
||||
return [hidden_states1, hidden_states2, pool2]
|
||||
|
||||
|
||||
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
||||
|
||||
Reference in New Issue
Block a user