add caching to disk for text encoder outputs

Kohya S
2023-07-16 14:53:47 +09:00
parent 62dd99bee5
commit 516f64f4d9
6 changed files with 537 additions and 142 deletions


@@ -204,10 +204,6 @@ def train(args):
             text_encoder2.gradient_checkpointing_enable()
         training_models.append(text_encoder1)
         training_models.append(text_encoder2)
-        text_encoder1_cache = None
-        text_encoder2_cache = None
         # set require_grad=True later
     else:
         text_encoder1.requires_grad_(False)
@@ -218,9 +214,15 @@ def train(args):
     # cache the TextEncoder outputs
     if args.cache_text_encoder_outputs:
         # Text Encodes are eval and no grad
-        text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
-            args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
-        )
+        with torch.no_grad():
+            train_dataset_group.cache_text_encoder_outputs(
+                (tokenizer1, tokenizer2),
+                (text_encoder1, text_encoder2),
+                accelerator.device,
+                None,
+                args.cache_text_encoder_outputs_to_disk,
+                accelerator.is_main_process,
+            )
         accelerator.wait_for_everyone()
     if not cache_latents:
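
The new code delegates caching to the dataset group and adds a cache-to-disk switch (args.cache_text_encoder_outputs_to_disk). As a rough sketch of what such a dataset-level caching pass can look like, assuming hypothetical image_info objects with absolute_path/caption attributes and an illustrative .npz-per-image layout (these names are not necessarily what the repo implements):

import os
import numpy as np
import torch

def cache_text_encoder_outputs_sketch(image_infos, tokenizers, text_encoders, device, cache_to_disk, is_main_process):
    # Encode every caption once with both SDXL text encoders and stash the result,
    # either on the image_info (in memory) or as an .npz file next to the image (on disk).
    text_encoders = [te.to(device).eval() for te in text_encoders]
    with torch.no_grad():
        for info in image_infos:
            npz_path = os.path.splitext(info.absolute_path)[0] + "_te_outputs.npz"  # assumed naming scheme
            if cache_to_disk and os.path.exists(npz_path):
                continue  # already cached in a previous run

            ids1 = tokenizers[0](info.caption, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
            ids2 = tokenizers[1](info.caption, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)

            out1 = text_encoders[0](ids1, output_hidden_states=True)
            out2 = text_encoders[1](ids2, output_hidden_states=True)
            hidden1 = out1.hidden_states[-2]  # penultimate layer, as SDXL conditioning uses
            hidden2 = out2.hidden_states[-2]
            pool2 = out2.text_embeds          # pooled projection of the second encoder

            if cache_to_disk:
                if is_main_process:  # only one process writes; the others reuse the files
                    np.savez(
                        npz_path,
                        hidden1.squeeze(0).float().cpu().numpy(),
                        hidden2.squeeze(0).float().cpu().numpy(),
                        pool2.squeeze(0).float().cpu().numpy(),
                    )
            else:
                info.text_encoder_outputs1 = hidden1.squeeze(0).cpu()
                info.text_encoder_outputs2 = hidden2.squeeze(0).cpu()
                info.text_encoder_pool2 = pool2.squeeze(0).cpu()

Once the outputs are cached this way, the text encoders no longer need to run on every training step when they are not being trained, and the accelerator.wait_for_everyone() call above keeps the processes in sync around the caching pass.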
@@ -375,11 +377,10 @@ def train(args):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
b_size = latents.shape[0]
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
if not args.cache_text_encoder_outputs:
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
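
The training loop now branches on whether the dataloader already put cached text encoder outputs into the batch, instead of on the flag alone. For the cached branch further below to work, the dataset/dataloader side has to read the cached tensors back and expose them under the batch keys checked here. A minimal loader sketch, assuming the same hypothetical .npz layout as the sketch above:

import numpy as np
import torch

def load_cached_text_encoder_outputs(npz_path):
    # Returns the three per-sample tensors the training loop later reads from the batch:
    # text_encoder_outputs1_list, text_encoder_outputs2_list, text_encoder_pool2_list.
    data = np.load(npz_path)
    hidden1 = torch.from_numpy(data["arr_0"])
    hidden2 = torch.from_numpy(data["arr_1"])
    pool2 = torch.from_numpy(data["arr_2"])
    return hidden1, hidden2, pool2

The collate step would then stack these per-sample tensors into the *_list entries of the batch, or leave them as None when caching is disabled, which is exactly what the new condition tests.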
@@ -395,8 +396,8 @@ def train(args):
                         # else:
                         input_ids1 = input_ids1.to(accelerator.device)
                         input_ids2 = input_ids2.to(accelerator.device)
-                        encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
-                            args,
+                        encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
+                            args.max_token_length,
                             input_ids1,
                             input_ids2,
                             tokenizer1,
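
Passing args.max_token_length rather than the whole args namespace decouples the helper from the training script, so the same shared train_util function can presumably also be called from the dataset-side caching pass. A usage example with the signature visible in this diff (variable values hypothetical):

# Illustrative call; max_token_length replaces the old args argument, and the final
# dtype argument is None unless full fp16 is used.
hidden1, hidden2, pool2 = train_util.get_hidden_states_sdxl(
    max_token_length,
    input_ids1.to(device),
    input_ids2.to(device),
    tokenizer1,
    tokenizer2,
    text_encoder1,
    text_encoder2,
    None,
)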
@@ -406,19 +407,26 @@ def train(args):
                             None if not args.full_fp16 else weight_dtype,
                         )
                 else:
-                    encoder_hidden_states1 = []
-                    encoder_hidden_states2 = []
-                    pool2 = []
-                    for input_id1, input_id2 in zip(input_ids1, input_ids2):
-                        input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
-                        input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
-                        encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
-                        hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
-                        encoder_hidden_states2.append(hidden_states2)
-                        pool2.append(p2)
-                    encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
-                    encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
-                    pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
+                    encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
+                    encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
+                    pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
+                    # # verify that the text encoder outputs are correct
+                    # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
+                    #     args.max_token_length,
+                    #     batch["input_ids"].to(text_encoder1.device),
+                    #     batch["input_ids2"].to(text_encoder1.device),
+                    #     tokenizer1,
+                    #     tokenizer2,
+                    #     text_encoder1,
+                    #     text_encoder2,
+                    #     None if not args.full_fp16 else weight_dtype,
+                    # )
+                    # b_size = encoder_hidden_states1.shape[0]
+                    # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # print("text encoder outputs verified")
                 # get size embeddings
                 orig_size = batch["original_sizes_hw"]
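
The commented-out block in the hunk above recomputes the hidden states and compares them against the cached tensors. A compact, runnable version of that sanity check (tolerance as in the comment; illustrative only, not part of the commit):

import torch

def verify_cached_text_encoder_outputs(cached, fresh, atol=1e-2):
    # cached / fresh: tuples of (encoder_hidden_states1, encoder_hidden_states2, pool2).
    # Small differences are expected when the cache was written in a lower precision.
    for name, c, f in zip(("encoder_hidden_states1", "encoder_hidden_states2", "pool2"), cached, fresh):
        max_diff = (c.float().cpu() - f.float().cpu()).abs().max().item()
        assert max_diff <= atol, f"{name}: cached and recomputed outputs differ by {max_diff}"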