mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
support weighted captions for SD/SDXL
This commit is contained in:
17
fine_tune.py
17
fine_tune.py
@@ -366,22 +366,17 @@ def train(args):
|
|||||||
with torch.set_grad_enabled(args.train_text_encoder):
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
# TODO move to strategy_sd.py
|
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||||
encoder_hidden_states = get_weighted_text_embeddings(
|
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
|
||||||
tokenize_strategy.tokenizer,
|
tokenize_strategy, [text_encoder], input_ids_list, weights_list
|
||||||
text_encoder,
|
)[0]
|
||||||
batch["captions"],
|
|
||||||
accelerator.device,
|
|
||||||
args.max_token_length // 75 if args.max_token_length else 1,
|
|
||||||
clip_skip=args.clip_skip,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
||||||
encoder_hidden_states = text_encoding_strategy.encode_tokens(
|
encoder_hidden_states = text_encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [text_encoder], [input_ids]
|
tokenize_strategy, [text_encoder], [input_ids]
|
||||||
)[0]
|
)[0]
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
|
|||||||
@@ -363,9 +363,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
|
|||||||
# )
|
# )
|
||||||
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||||
|
|
||||||
assert (
|
# assert (
|
||||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
# not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
# ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||||
|
|
||||||
if supportTextEncoderCaching:
|
if supportTextEncoderCaching:
|
||||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||||
|
|||||||
@@ -323,12 +323,18 @@ class TextEncoderOutputsCachingStrategy:
|
|||||||
_strategy = None # strategy instance: actual strategy class
|
_strategy = None # strategy instance: actual strategy class
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
|
self,
|
||||||
|
cache_to_disk: bool,
|
||||||
|
batch_size: int,
|
||||||
|
skip_disk_cache_validity_check: bool,
|
||||||
|
is_partial: bool = False,
|
||||||
|
is_weighted: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._cache_to_disk = cache_to_disk
|
self._cache_to_disk = cache_to_disk
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||||
self._is_partial = is_partial
|
self._is_partial = is_partial
|
||||||
|
self._is_weighted = is_weighted
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_strategy(cls, strategy):
|
def set_strategy(cls, strategy):
|
||||||
@@ -352,6 +358,10 @@ class TextEncoderOutputsCachingStrategy:
|
|||||||
def is_partial(self):
|
def is_partial(self):
|
||||||
return self._is_partial
|
return self._is_partial
|
||||||
|
|
||||||
|
@property
def is_weighted(self) -> bool:
    """Return the ``is_weighted`` flag this strategy was constructed with."""
    return self._is_weighted
|
||||||
|
|
||||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,16 @@ class SdTokenizeStrategy(TokenizeStrategy):
|
|||||||
text = [text] if isinstance(text, str) else text
|
text = [text] if isinstance(text, str) else text
|
||||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
||||||
|
text = [text] if isinstance(text, str) else text
|
||||||
|
tokens_list = []
|
||||||
|
weights_list = []
|
||||||
|
for t in text:
|
||||||
|
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
|
||||||
|
tokens_list.append(tokens)
|
||||||
|
weights_list.append(weights)
|
||||||
|
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
|
||||||
|
|
||||||
|
|
||||||
class SdTextEncodingStrategy(TextEncodingStrategy):
|
class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||||
def __init__(self, clip_skip: Optional[int] = None) -> None:
|
def __init__(self, clip_skip: Optional[int] = None) -> None:
|
||||||
@@ -58,6 +68,8 @@ class SdTextEncodingStrategy(TextEncodingStrategy):
|
|||||||
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
|
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
|
||||||
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
|
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
|
||||||
|
|
||||||
|
tokens = tokens.to(text_encoder.device)
|
||||||
|
|
||||||
if self.clip_skip is None:
|
if self.clip_skip is None:
|
||||||
encoder_hidden_states = text_encoder(tokens)[0]
|
encoder_hidden_states = text_encoder(tokens)[0]
|
||||||
else:
|
else:
|
||||||
@@ -93,6 +105,30 @@ class SdTextEncodingStrategy(TextEncodingStrategy):
|
|||||||
|
|
||||||
return [encoder_hidden_states]
|
return [encoder_hidden_states]
|
||||||
|
|
||||||
|
def encode_tokens_with_weights(
    self,
    tokenize_strategy: TokenizeStrategy,
    models: List[Any],
    tokens_list: List[torch.Tensor],
    weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
    """Encode tokens via ``encode_tokens`` and scale the result by per-token weights.

    Args:
        tokenize_strategy: forwarded unchanged to ``encode_tokens``.
        models: forwarded unchanged to ``encode_tokens``.
        tokens_list: forwarded unchanged to ``encode_tokens``.
        weights_list: only ``weights_list[0]`` is read. It is indexed as
            ``(batch, n_chunks, chunk_width)``; presumably ``chunk_width``
            is 77 (75 text tokens plus BOS/EOS) -- confirm against the
            tokenizer strategy that produces it.

    Returns:
        A one-element list containing the weighted hidden states.
    """
    encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]

    weights = weights_list[0].to(encoder_hidden_states.device)

    # apply weights
    if weights.shape[1] == 1:  # no max_token_length
        # weights: (b, 1, 77), hidden_states: (b, 77, dim)
        encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
    else:
        # weights: (b, n, 77), hidden_states: (b, n*75+2, dim)
        # Every chunk carries its own first/last entry, which is dropped below
        # via [1:-1]; the hidden states keep a single leading and trailing
        # position, so chunk i occupies [i*inner + 1, (i+1)*inner + 1).
        inner = weights.shape[-1] - 2  # 75 for 77-wide chunks
        for i in range(weights.shape[1]):
            start = i * inner + 1
            stop = start + inner
            chunk_weights = weights[:, i, 1:-1].unsqueeze(-1)
            encoder_hidden_states[:, start:stop] = encoder_hidden_states[:, start:stop] * chunk_weights

    return [encoder_hidden_states]
|
||||||
|
|
||||||
|
|
||||||
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||||
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
||||||
|
|||||||
@@ -42,16 +42,16 @@ class SdxlTokenizeStrategy(TokenizeStrategy):
|
|||||||
tokens1_list, tokens2_list = [], []
|
tokens1_list, tokens2_list = [], []
|
||||||
weights1_list, weights2_list = [], []
|
weights1_list, weights2_list = [], []
|
||||||
for t in text:
|
for t in text:
|
||||||
tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length)
|
tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True)
|
||||||
tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length)
|
tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True)
|
||||||
tokens1_list.append(tokens1)
|
tokens1_list.append(tokens1)
|
||||||
tokens2_list.append(tokens2)
|
tokens2_list.append(tokens2)
|
||||||
weights1_list.append(weights1)
|
weights1_list.append(weights1)
|
||||||
weights2_list.append(weights2)
|
weights2_list.append(weights2)
|
||||||
return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), (
|
return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [
|
||||||
torch.stack(weights1_list, dim=0),
|
torch.stack(weights1_list, dim=0),
|
||||||
torch.stack(weights2_list, dim=0),
|
torch.stack(weights2_list, dim=0),
|
||||||
)
|
]
|
||||||
|
|
||||||
|
|
||||||
class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||||
@@ -193,20 +193,28 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
|||||||
return [hidden_states1, hidden_states2, pool2]
|
return [hidden_states1, hidden_states2, pool2]
|
||||||
|
|
||||||
def encode_tokens_with_weights(
|
def encode_tokens_with_weights(
|
||||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
|
self,
|
||||||
|
tokenize_strategy: TokenizeStrategy,
|
||||||
|
models: List[Any],
|
||||||
|
tokens_list: List[torch.Tensor],
|
||||||
|
weights_list: List[torch.Tensor],
|
||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens)
|
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list)
|
||||||
|
|
||||||
|
weights_list = [weights.to(hidden_states1.device) for weights in weights_list]
|
||||||
|
|
||||||
# apply weights
|
# apply weights
|
||||||
if weights[0].shape[1] == 1: # no max_token_length
|
if weights_list[0].shape[1] == 1: # no max_token_length
|
||||||
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
||||||
hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2)
|
hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2)
|
||||||
hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2)
|
hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2)
|
||||||
else:
|
else:
|
||||||
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
||||||
for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]):
|
for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]):
|
||||||
for i in range(weight.shape[1]):
|
for i in range(weight.shape[1]):
|
||||||
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1]
|
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[
|
||||||
|
:, i, 1:-1
|
||||||
|
].unsqueeze(-1)
|
||||||
|
|
||||||
return [hidden_states1, hidden_states2, pool2]
|
return [hidden_states1, hidden_states2, pool2]
|
||||||
|
|
||||||
@@ -215,9 +223,14 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
|
self,
|
||||||
|
cache_to_disk: bool,
|
||||||
|
batch_size: int,
|
||||||
|
skip_disk_cache_validity_check: bool,
|
||||||
|
is_partial: bool = False,
|
||||||
|
is_weighted: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
|
||||||
|
|
||||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||||
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||||
@@ -253,11 +266,19 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
|
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
|
||||||
captions = [info.caption for info in infos]
|
captions = [info.caption for info in infos]
|
||||||
|
|
||||||
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
|
if self.is_weighted:
|
||||||
with torch.no_grad():
|
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
|
||||||
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
|
with torch.no_grad():
|
||||||
tokenize_strategy, models, [tokens1, tokens2]
|
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights(
|
||||||
)
|
tokenize_strategy, models, tokens_list, weights_list
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
|
||||||
|
with torch.no_grad():
|
||||||
|
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy, models, [tokens1, tokens2]
|
||||||
|
)
|
||||||
|
|
||||||
if hidden_state1.dtype == torch.bfloat16:
|
if hidden_state1.dtype == torch.bfloat16:
|
||||||
hidden_state1 = hidden_state1.float()
|
hidden_state1 = hidden_state1.float()
|
||||||
if hidden_state2.dtype == torch.bfloat16:
|
if hidden_state2.dtype == torch.bfloat16:
|
||||||
|
|||||||
@@ -321,7 +321,7 @@ def train(args):
|
|||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
# Text Encodes are eval and no grad
|
# Text Encodes are eval and no grad
|
||||||
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||||
args.cache_text_encoder_outputs_to_disk, None, False
|
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
|
||||||
)
|
)
|
||||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,9 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
|
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||||
|
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
16
train_db.py
16
train_db.py
@@ -356,21 +356,17 @@ def train(args):
|
|||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
encoder_hidden_states = get_weighted_text_embeddings(
|
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||||
tokenize_strategy.tokenizer,
|
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
|
||||||
text_encoder,
|
tokenize_strategy, [text_encoder], input_ids_list, weights_list
|
||||||
batch["captions"],
|
)[0]
|
||||||
accelerator.device,
|
|
||||||
args.max_token_length // 75 if args.max_token_length else 1,
|
|
||||||
clip_skip=args.clip_skip,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
||||||
encoder_hidden_states = text_encoding_strategy.encode_tokens(
|
encoder_hidden_states = text_encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [text_encoder], [input_ids]
|
tokenize_strategy, [text_encoder], [input_ids]
|
||||||
)[0]
|
)[0]
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
|
|||||||
Reference in New Issue
Block a user