mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support weighted captions for sdxl LoRA and fine tuning
This commit is contained in:
@@ -1123,14 +1123,21 @@ class NetworkTrainer:
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
# SD only
|
||||
encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||
tokenizers[0],
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
# # SD only
|
||||
# encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||
# tokenizers[0],
|
||||
# text_encoder,
|
||||
# batch["captions"],
|
||||
# accelerator.device,
|
||||
# args.max_token_length // 75 if args.max_token_length else 1,
|
||||
# clip_skip=args.clip_skip,
|
||||
# )
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids_list,
|
||||
weights_list,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
@@ -1139,8 +1146,8 @@ class NetworkTrainer:
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
|
||||
Reference in New Issue
Block a user