mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'sd3' into multi-gpu-caching
This commit is contained in:
@@ -1042,7 +1042,9 @@ class NetworkTrainer:
|
||||
text_encoder = None
|
||||
|
||||
# For --sample_at_first
|
||||
optimizer_eval_fn()
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
optimizer_train_fn()
|
||||
if len(accelerator.trackers) > 0:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
@@ -1121,14 +1123,21 @@ class NetworkTrainer:
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
# SD only
|
||||
encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||
tokenizers[0],
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
# # SD only
|
||||
# encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||
# tokenizers[0],
|
||||
# text_encoder,
|
||||
# batch["captions"],
|
||||
# accelerator.device,
|
||||
# args.max_token_length // 75 if args.max_token_length else 1,
|
||||
# clip_skip=args.clip_skip,
|
||||
# )
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids_list,
|
||||
weights_list,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
@@ -1137,8 +1146,8 @@ class NetworkTrainer:
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
|
||||
Reference in New Issue
Block a user