support weighted captions for SD/SDXL

This commit is contained in:
Kohya S
2024-10-11 08:48:55 +09:00
parent 886f75345c
commit f2bc820133
8 changed files with 105 additions and 45 deletions

View File

@@ -323,12 +323,18 @@ class TextEncoderOutputsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._is_partial = is_partial
self._is_weighted = is_weighted
@classmethod
def set_strategy(cls, strategy):
@@ -352,6 +358,10 @@ class TextEncoderOutputsCachingStrategy:
def is_partial(self):
return self._is_partial
@property
def is_weighted(self):
return self._is_weighted
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError