support weighted captions for sdxl LoRA and fine tuning

This commit is contained in:
Kohya S
2024-10-10 08:27:15 +09:00
parent 126159f7c4
commit 886f75345c
5 changed files with 45 additions and 35 deletions

View File

@@ -74,6 +74,9 @@ class TokenizeStrategy:
raise NotImplementedError
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
"""
raise NotImplementedError
def _get_weighted_input_ids(
@@ -303,7 +306,7 @@ class TextEncodingStrategy:
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]: