mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add caching to disk for text encoder outputs
This commit is contained in:
@@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
return args.cache_text_encoder_outputs
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||
):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
@@ -60,34 +60,33 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
||||
args, accelerator, tokenizers, text_encoders, dataset, weight_dtype
|
||||
dataset.cache_text_encoder_outputs(
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.text_encoder1_cache = text_encoder1_cache
|
||||
self.text_encoder2_cache = text_encoder2_cache
|
||||
|
||||
if not args.lowram:
|
||||
print("move vae and unet back to original device")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
self.text_encoder1_cache = None
|
||||
self.text_encoder2_cache = None
|
||||
|
||||
# The Text Encoder outputs are fetched on every step, so keep the encoders on the GPU
|
||||
text_encoders[0].to(accelerator.device)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
if not args.cache_text_encoder_outputs:
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
with torch.enable_grad():
|
||||
# Get the text embedding for conditioning
|
||||
# TODO support weighted captions
|
||||
@@ -103,8 +102,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
# else:
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
|
||||
args,
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizers[0],
|
||||
@@ -114,19 +113,27 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states1 = []
|
||||
encoder_hidden_states2 = []
|
||||
pool2 = []
|
||||
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
||||
input_id1_cache_key = tuple(input_id1.flatten().tolist())
|
||||
input_id2_cache_key = tuple(input_id2.flatten().tolist())
|
||||
encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
|
||||
hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
|
||||
encoder_hidden_states2.append(hidden_states2)
|
||||
pool2.append(p2)
|
||||
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
||||
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
||||
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
# # verify that the text encoder outputs are correct
|
||||
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
||||
# args.max_token_length,
|
||||
# batch["input_ids"].to(text_encoders[0].device),
|
||||
# batch["input_ids2"].to(text_encoders[0].device),
|
||||
# tokenizers[0],
|
||||
# tokenizers[1],
|
||||
# text_encoders[0],
|
||||
# text_encoders[1],
|
||||
# None if not args.full_fp16 else weight_dtype,
|
||||
# )
|
||||
# b_size = encoder_hidden_states1.shape[0]
|
||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# print("text encoder outputs verified")
|
||||
|
||||
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user