Merge branch 'sd3' into multi-gpu-caching

This commit is contained in:
kohya-ss
2024-10-12 16:39:36 +09:00
24 changed files with 1790 additions and 352 deletions

View File

@@ -1042,7 +1042,9 @@ class NetworkTrainer:
text_encoder = None
# For --sample_at_first
optimizer_eval_fn()
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
@@ -1121,14 +1123,21 @@ class NetworkTrainer:
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
# SD only
encoded_text_encoder_conds = get_weighted_text_embeddings(
tokenizers[0],
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
# # SD only
# encoded_text_encoder_conds = get_weighted_text_embeddings(
# tokenizers[0],
# text_encoder,
# batch["captions"],
# accelerator.device,
# args.max_token_length // 75 if args.max_token_length else 1,
# clip_skip=args.clip_skip,
# )
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids_list,
weights_list,
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
@@ -1137,8 +1146,8 @@ class NetworkTrainer:
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0: