mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support weighted captions for sdxl LoRA and fine tuning
This commit is contained in:
@@ -74,6 +74,9 @@ class TokenizeStrategy:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
"""
|
||||
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_weighted_input_ids(
|
||||
@@ -303,7 +306,7 @@ class TextEncodingStrategy:
|
||||
:return: list of output embeddings for each architecture
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
|
||||
@@ -174,7 +174,8 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy: TokenizeStrategy
|
||||
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]
|
||||
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
|
||||
If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
|
||||
tokens: List of tokens, for text_encoder1 and text_encoder2
|
||||
"""
|
||||
if len(models) == 2:
|
||||
|
||||
Reference in New Issue
Block a user