mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support weighted captions for SD/SDXL
This commit is contained in:
@@ -40,6 +40,16 @@ class SdTokenizeStrategy(TokenizeStrategy):
|
||||
text = [text] if isinstance(text, str) else text
|
||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
tokens_list = []
|
||||
weights_list = []
|
||||
for t in text:
|
||||
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
|
||||
tokens_list.append(tokens)
|
||||
weights_list.append(weights)
|
||||
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
|
||||
|
||||
|
||||
class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self, clip_skip: Optional[int] = None) -> None:
|
||||
@@ -58,6 +68,8 @@ class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
|
||||
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
|
||||
|
||||
tokens = tokens.to(text_encoder.device)
|
||||
|
||||
if self.clip_skip is None:
|
||||
encoder_hidden_states = text_encoder(tokens)[0]
|
||||
else:
|
||||
@@ -93,6 +105,30 @@ class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
return [encoder_hidden_states]
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens_list: List[torch.Tensor],
|
||||
weights_list: List[torch.Tensor],
|
||||
) -> List[torch.Tensor]:
|
||||
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
|
||||
|
||||
weights = weights_list[0].to(encoder_hidden_states.device)
|
||||
|
||||
# apply weights
|
||||
if weights.shape[1] == 1: # no max_token_length
|
||||
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
||||
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
|
||||
else:
|
||||
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
||||
for i in range(weights.shape[1]):
|
||||
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
|
||||
:, i, 1:-1
|
||||
].unsqueeze(-1)
|
||||
|
||||
return [encoder_hidden_states]
|
||||
|
||||
|
||||
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
||||
|
||||
Reference in New Issue
Block a user