sample generation in SDXL ControlNet training

This commit is contained in:
Kohya S
2024-09-30 23:39:32 +09:00
parent d78f6a775c
commit 793999d116
5 changed files with 322 additions and 165 deletions

View File

@@ -37,6 +37,22 @@ class SdxlTokenizeStrategy(TokenizeStrategy):
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
)
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens1_list, tokens2_list = [], []
weights1_list, weights2_list = [], []
for t in text:
tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length)
tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length)
tokens1_list.append(tokens1)
tokens2_list.append(tokens2)
weights1_list.append(weights1)
weights2_list.append(weights2)
return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), (
torch.stack(weights1_list, dim=0),
torch.stack(weights2_list, dim=0),
)
class SdxlTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
@@ -98,7 +114,10 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
if input_ids1.size()[1] == 1:
max_token_length = None
else:
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
input_ids1 = input_ids1.to(text_encoder1.device)
@@ -172,6 +191,24 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
)
return [hidden_states1, hidden_states2, pool2]
def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens)
# apply weights
if weights[0].shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2)
hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]):
for i in range(weight.shape[1]):
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1]
return [hidden_states1, hidden_states2, pool2]
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"