add caching to disk for text encoder outputs

Kohya S
2023-07-16 14:53:47 +09:00
parent 62dd99bee5
commit 516f64f4d9
6 changed files with 537 additions and 142 deletions
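For orientation before the diff: the idea behind this change is to run the SDXL text encoders over every caption once, store the resulting hidden states and pooled embedding, and reload them at each training step instead of re-encoding. The sketch below only illustrates that save/reload pattern; the function names and the .npz layout are assumptions, not the commit's actual code.

import numpy as np
import torch

def save_text_encoder_outputs(cache_path, hidden_states1, hidden_states2, pool2):
    # hypothetical helper: persist one sample's text encoder outputs as an .npz file
    np.savez(
        cache_path,
        hidden_states1=hidden_states1.float().cpu().numpy(),
        hidden_states2=hidden_states2.float().cpu().numpy(),
        pool2=pool2.float().cpu().numpy(),
    )

def load_text_encoder_outputs(cache_path, device, dtype):
    # hypothetical helper: reload the cached outputs for use in a training step
    data = np.load(cache_path)
    return (
        torch.from_numpy(data["hidden_states1"]).to(device, dtype=dtype),
        torch.from_numpy(data["hidden_states2"]).to(device, dtype=dtype),
        torch.from_numpy(data["pool2"]).to(device, dtype=dtype),
    )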


@@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
         return args.cache_text_encoder_outputs
 
     def cache_text_encoder_outputs_if_needed(
-        self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype
+        self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
     ):
         if args.cache_text_encoder_outputs:
             if not args.lowram:
@@ -60,34 +60,33 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
 
-            text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
-                args, accelerator, tokenizers, text_encoders, dataset, weight_dtype
+            dataset.cache_text_encoder_outputs(
+                tokenizers,
+                text_encoders,
+                accelerator.device,
+                weight_dtype,
+                args.cache_text_encoder_outputs_to_disk,
+                accelerator.is_main_process,
             )
 
             accelerator.wait_for_everyone()
             text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
             text_encoders[1].to("cpu", dtype=torch.float32)
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
-            self.text_encoder1_cache = text_encoder1_cache
-            self.text_encoder2_cache = text_encoder2_cache
 
             if not args.lowram:
                 print("move vae and unet back to original device")
                 vae.to(org_vae_device)
                 unet.to(org_unet_device)
         else:
-            self.text_encoder1_cache = None
-            self.text_encoder2_cache = None
+            # outputs are fetched from the Text Encoders every step, so keep them on the GPU
+            text_encoders[0].to(accelerator.device)
+            text_encoders[1].to(accelerator.device)
 
     def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
-        input_ids1 = batch["input_ids"]
-        input_ids2 = batch["input_ids2"]
-        if not args.cache_text_encoder_outputs:
+        if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
+            input_ids1 = batch["input_ids"]
+            input_ids2 = batch["input_ids2"]
             with torch.enable_grad():
                 # Get the text embedding for conditioning
                 # TODO support weighted captions
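The dataset-side method called above, dataset.cache_text_encoder_outputs, is defined in train_util (one of the other files in this commit, not shown here). The following is only a rough sketch of how such a step could be structured, reusing the get_hidden_states_sdxl call that appears in the next hunk; samples, sample.tokenize, sample.npz_path and the library import path are assumptions for illustration.

import numpy as np
import torch
from library import train_util  # import path assumed; the diff only shows the module name train_util

def cache_text_encoder_outputs_sketch(samples, tokenizers, text_encoders, device, weight_dtype,
                                      cache_to_disk, is_main_process, max_token_length=None):
    # Rough sketch, not the actual DatasetGroup method: encode every caption once,
    # then either keep the result in memory or write one .npz file per sample.
    for sample in samples:
        input_ids1 = sample.tokenize(tokenizers[0]).to(device)  # hypothetical per-sample tokenization
        input_ids2 = sample.tokenize(tokenizers[1]).to(device)
        with torch.no_grad():
            hs1, hs2, pool2 = train_util.get_hidden_states_sdxl(
                max_token_length, input_ids1, input_ids2,
                tokenizers[0], tokenizers[1], text_encoders[0], text_encoders[1], weight_dtype,
            )
        if cache_to_disk:
            if is_main_process:  # write the cache files only once in multi-process training
                np.savez(sample.npz_path, hidden_states1=hs1.cpu().numpy(),
                         hidden_states2=hs2.cpu().numpy(), pool2=pool2.cpu().numpy())
        else:
            # keep on CPU so the dataloader can place the tensors into the batch later
            sample.text_encoder_outputs = (hs1.cpu(), hs2.cpu(), pool2.cpu())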
@@ -103,8 +102,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                # else:
                input_ids1 = input_ids1.to(accelerator.device)
                input_ids2 = input_ids2.to(accelerator.device)
-                encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
-                    args,
+                encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
+                    args.max_token_length,
                     input_ids1,
                     input_ids2,
                     tokenizers[0],
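For reference, a simplified view of what an SDXL hidden-state helper like the one called here returns, and therefore what gets cached: the penultimate-layer hidden states of both text encoders plus the projected pooled embedding of the second encoder. This sketch ignores the long-prompt chunking that max_token_length enables and is not the repository's exact get_hidden_states_sdxl.

def get_hidden_states_sdxl_simplified(input_ids1, input_ids2, text_encoder1, text_encoder2, dtype=None):
    # simplified sketch only; assumes text_encoder2 is a CLIPTextModelWithProjection, as in SDXL
    out1 = text_encoder1(input_ids1, output_hidden_states=True)
    out2 = text_encoder2(input_ids2, output_hidden_states=True)
    hidden_states1 = out1.hidden_states[-2]  # penultimate layer of text encoder 1
    hidden_states2 = out2.hidden_states[-2]  # penultimate layer of text encoder 2
    pool2 = out2.text_embeds                 # projected pooled embedding of text encoder 2
    if dtype is not None:                    # the full_fp16 path casts the outputs
        hidden_states1, hidden_states2, pool2 = (
            hidden_states1.to(dtype), hidden_states2.to(dtype), pool2.to(dtype)
        )
    return hidden_states1, hidden_states2, pool2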
@@ -114,19 +113,27 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                     None if not args.full_fp16 else weight_dtype,
                 )
         else:
-            encoder_hidden_states1 = []
-            encoder_hidden_states2 = []
-            pool2 = []
-            for input_id1, input_id2 in zip(input_ids1, input_ids2):
-                input_id1_cache_key = tuple(input_id1.flatten().tolist())
-                input_id2_cache_key = tuple(input_id2.flatten().tolist())
-                encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
-                hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
-                encoder_hidden_states2.append(hidden_states2)
-                pool2.append(p2)
-            encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
-            encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
-            pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
+            encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
+            encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
+            pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
+
+            # # verify that the text encoder outputs are correct
+            # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
+            #     args.max_token_length,
+            #     batch["input_ids"].to(text_encoders[0].device),
+            #     batch["input_ids2"].to(text_encoders[0].device),
+            #     tokenizers[0],
+            #     tokenizers[1],
+            #     text_encoders[0],
+            #     text_encoders[1],
+            #     None if not args.full_fp16 else weight_dtype,
+            # )
+            # b_size = encoder_hidden_states1.shape[0]
+            # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+            # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+            # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+            # print("text encoder outputs verified")
+
         return encoder_hidden_states1, encoder_hidden_states2, pool2
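The commented-out block above is a debugging aid: recompute the text encoder outputs for the batch and confirm the cached tensors are numerically close. A tightened sketch of that check; assert_cached_outputs_match is a hypothetical name, not part of the commit.

def assert_cached_outputs_match(cached, recomputed, weight_dtype, atol=1e-2):
    # cached / recomputed are (hidden_states1, hidden_states2, pool2) tuples of tensors
    for name, c, r in zip(("hidden_states1", "hidden_states2", "pool2"), cached, recomputed):
        max_diff = (c.to("cpu") - r.to("cpu", dtype=weight_dtype)).abs().max().item()
        assert max_diff <= atol, f"{name}: cached output differs from recomputed output by {max_diff}"

Called with the three cached tensors and the recomputed ehs1/ehs2/p2 values from the commented block, it would fail loudly if a stale or mismatched cache file were loaded.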