mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add caching to disk for text encoder outputs
This commit is contained in:
@@ -204,10 +204,6 @@ def train(args):
|
||||
text_encoder2.gradient_checkpointing_enable()
|
||||
training_models.append(text_encoder1)
|
||||
training_models.append(text_encoder2)
|
||||
|
||||
text_encoder1_cache = None
|
||||
text_encoder2_cache = None
|
||||
|
||||
# set require_grad=True later
|
||||
else:
|
||||
text_encoder1.requires_grad_(False)
|
||||
@@ -218,9 +214,15 @@ def train(args):
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
||||
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
|
||||
)
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_text_encoder_outputs(
|
||||
(tokenizer1, tokenizer2),
|
||||
(text_encoder1, text_encoder2),
|
||||
accelerator.device,
|
||||
None,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if not cache_latents:
|
||||
@@ -375,11 +377,10 @@ def train(args):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
b_size = latents.shape[0]
|
||||
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
if not args.cache_text_encoder_outputs:
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
# TODO support weighted captions
|
||||
@@ -395,8 +396,8 @@ def train(args):
|
||||
# else:
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
|
||||
args,
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizer1,
|
||||
@@ -406,19 +407,26 @@ def train(args):
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states1 = []
|
||||
encoder_hidden_states2 = []
|
||||
pool2 = []
|
||||
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
||||
input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
|
||||
input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
|
||||
encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
|
||||
hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
|
||||
encoder_hidden_states2.append(hidden_states2)
|
||||
pool2.append(p2)
|
||||
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
||||
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
||||
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
# # verify that the text encoder outputs are correct
|
||||
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
||||
# args.max_token_length,
|
||||
# batch["input_ids"].to(text_encoder1.device),
|
||||
# batch["input_ids2"].to(text_encoder1.device),
|
||||
# tokenizer1,
|
||||
# tokenizer2,
|
||||
# text_encoder1,
|
||||
# text_encoder2,
|
||||
# None if not args.full_fp16 else weight_dtype,
|
||||
# )
|
||||
# b_size = encoder_hidden_states1.shape[0]
|
||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# print("text encoder outputs verified")
|
||||
|
||||
# get size embeddings
|
||||
orig_size = batch["original_sizes_hw"]
|
||||
|
||||
Reference in New Issue
Block a user