support weighted captions for sdxl LoRA and fine tuning

This commit is contained in:
Kohya S
2024-10-10 08:27:15 +09:00
parent 126159f7c4
commit 886f75345c
5 changed files with 45 additions and 35 deletions

View File

@@ -104,8 +104,8 @@ def train(args):
setup_logging(args, reset=True)
assert (
not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
not args.weighted_captions or not args.cache_text_encoder_outputs
), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません"
assert (
not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
@@ -660,22 +660,24 @@ def train(args):
input_ids1, input_ids2 = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
# if args.weighted_captions:
# encoder_hidden_states = get_weighted_text_embeddings(
# tokenizer,
# text_encoder,
# batch["captions"],
# accelerator.device,
# args.max_token_length // 75 if args.max_token_length else 1,
# clip_skip=args.clip_skip,
# )
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
)
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states1, encoder_hidden_states2, pool2 = (
text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
[text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
input_ids_list,
weights_list,
)
)
else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
[input_ids1, input_ids2],
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)