mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)

Commit: add caching to disk for text encoder outputs
@@ -104,6 +104,8 @@ IMAGE_TRANSFORMS = transforms.Compose(
     ]
 )
 
+TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
+
 
 class ImageInfo:
     def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -122,6 +124,11 @@ class ImageInfo:
         self.latents_crop_left_top: Tuple[int, int] = None  # original image crop left top, not latents crop left top
         self.cond_img_path: str = None
         self.image: Optional[Image.Image] = None  # optional, original PIL Image
+        # SDXL, optional
+        self.text_encoder_outputs_npz: Optional[str] = None
+        self.text_encoder_outputs1: Optional[torch.Tensor] = None
+        self.text_encoder_outputs2: Optional[torch.Tensor] = None
+        self.text_encoder_pool2: Optional[torch.Tensor] = None
 
 
 class BucketManager:
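The cache file name is derived from the image path by swapping the extension for the new suffix, matching the derivation used later in the commit. A minimal sketch (the image path is hypothetical):

import os

TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"

image_path = "/data/train/0001.png"  # hypothetical image path
te_out_npz = os.path.splitext(image_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
print(te_out_npz)  # /data/train/0001_te_outputs.npz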
@@ -793,7 +800,7 @@ class BaseDataset(torch.utils.data.Dataset):
         )
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
-        # made this a bit faster
+        # multi-GPU is not supported; use tools/cache_latents.py for that
         print("caching latents.")
 
         image_infos = list(self.image_data.values())
@@ -841,9 +848,73 @@ class BaseDataset(torch.utils.data.Dataset):
             return
 
         # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
         print("caching latents...")
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
             cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
 
+    # if weight_dtype is specified, the Text Encoder itself and its outputs are cast to weight_dtype
+    # only effective for SDXL, but implemented here rather than in sdxl_train_util.py because it must be a dataset method
+    # supporting SD1/2 is deferred because it would require holding a v2 flag
+    def cache_text_encoder_outputs(
+        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
+    ):
+        assert len(tokenizers) == 2, "only support SDXL"
+
+        # as with the latents cache, caching to disk is supported
+        # multi-GPU is not supported either; use tools/cache_latents.py for that
+        print("caching text encoder outputs.")
+        image_infos = list(self.image_data.values())
+
+        print("checking cache existence...")
+        image_infos_to_cache = []
+        for info in tqdm(image_infos):
+            # subset = self.image_to_subset[info.image_key]
+            if cache_to_disk:
+                te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
+                info.text_encoder_outputs_npz = te_out_npz
+
+                if not is_main_process:  # store to info only
+                    continue
+
+                if os.path.exists(te_out_npz):
+                    continue
+
+            image_infos_to_cache.append(info)
+
+        if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
+            return
+
+        # prepare tokenizers and text encoders
+        for text_encoder in text_encoders:
+            text_encoder.to(device)
+            if weight_dtype is not None:
+                text_encoder.to(dtype=weight_dtype)
+
+        # create batch
+        batch = []
+        batches = []
+        for info in image_infos_to_cache:
+            input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
+            input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
+            batch.append((info, input_ids1, input_ids2))
+
+            if len(batch) >= self.batch_size:
+                batches.append(batch)
+                batch = []
+
+        if len(batch) > 0:
+            batches.append(batch)
+
+        # iterate batches: call text encoder and cache outputs for memory or disk
+        print("caching text encoder outputs...")
+        for batch in tqdm(batches):
+            infos, input_ids1, input_ids2 = zip(*batch)
+            input_ids1 = torch.stack(input_ids1, dim=0)
+            input_ids2 = torch.stack(input_ids2, dim=0)
+            cache_batch_text_encoder_outputs(
+                infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
+            )
+
     def get_image_size(self, image_path):
         image = Image.open(image_path)
         return image.size
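A minimal sketch of how a training script could invoke the new method before the training loop; dataset_group, the tokenizer/encoder pairs, accelerator, and weight_dtype are assumed to be set up elsewhere, and only the call signature comes from the method added above:

# hypothetical surrounding variables: dataset_group, tokenizer1/2, text_encoder1/2, accelerator, weight_dtype
dataset_group.cache_text_encoder_outputs(
    [tokenizer1, tokenizer2],          # SDXL uses two tokenizers
    [text_encoder1, text_encoder2],    # and two text encoders
    accelerator.device,
    weight_dtype,
    cache_to_disk=True,                # writes <image>_te_outputs.npz next to each image
    is_main_process=accelerator.is_main_process,
)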
@@ -931,6 +1002,9 @@ class BaseDataset(torch.utils.data.Dataset):
         crop_top_lefts = []
         target_sizes_hw = []
         flippeds = []  # the variable name is not great
+        text_encoder_outputs1_list = []
+        text_encoder_outputs2_list = []
+        text_encoder_pool2_list = []
 
         for image_key in bucket[image_index : image_index + bucket_batch_size]:
             image_info = self.image_data[image_key]
@@ -1012,44 +1086,76 @@ class BaseDataset(torch.utils.data.Dataset):
             target_sizes_hw.append((target_size[1], target_size[0]))
             flippeds.append(flipped)
 
-            caption = self.process_caption(subset, image_info.caption)
-            if self.XTI_layers:
-                caption_layer = []
-                for layer in self.XTI_layers:
-                    token_strings_from = " ".join(self.token_strings)
-                    token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
-                    caption_ = caption.replace(token_strings_from, token_strings_to)
-                    caption_layer.append(caption_)
-                captions.append(caption_layer)
-            else:
-                captions.append(caption)
-
-            if not self.token_padding_disabled:  # this option might be omitted in future
-                if self.XTI_layers:
-                    token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
-                else:
-                    token_caption = self.get_input_ids(caption, self.tokenizers[0])
-                input_ids_list.append(token_caption)
-
-                if len(self.tokenizers) > 1:
-                    if self.XTI_layers:
-                        token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
-                    else:
-                        token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
-                    input_ids2_list.append(token_caption2)
+            # process the caption and the text encoder outputs
+            caption = image_info.caption  # default
+            if image_info.text_encoder_outputs1 is not None:
+                text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
+                text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
+                text_encoder_pool2_list.append(image_info.text_encoder_pool2)
+                captions.append(caption)
+            elif image_info.text_encoder_outputs_npz is not None:
+                text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
+                    image_info.text_encoder_outputs_npz
+                )
+                text_encoder_outputs1_list.append(text_encoder_outputs1)
+                text_encoder_outputs2_list.append(text_encoder_outputs2)
+                text_encoder_pool2_list.append(text_encoder_pool2)
+                captions.append(caption)
+            else:
+                caption = self.process_caption(subset, image_info.caption)
+                if self.XTI_layers:
+                    caption_layer = []
+                    for layer in self.XTI_layers:
+                        token_strings_from = " ".join(self.token_strings)
+                        token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
+                        caption_ = caption.replace(token_strings_from, token_strings_to)
+                        caption_layer.append(caption_)
+                    captions.append(caption_layer)
+                else:
+                    captions.append(caption)
+
+                if not self.token_padding_disabled:  # this option might be omitted in future
+                    if self.XTI_layers:
+                        token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
+                    else:
+                        token_caption = self.get_input_ids(caption, self.tokenizers[0])
+                    input_ids_list.append(token_caption)
+
+                    if len(self.tokenizers) > 1:
+                        if self.XTI_layers:
+                            token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
+                        else:
+                            token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
+                        input_ids2_list.append(token_caption2)
 
         example = {}
         example["loss_weights"] = torch.FloatTensor(loss_weights)
 
-        if self.token_padding_disabled:
-            # padding=True means pad in the batch
-            example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
-            if len(self.tokenizers) > 1:
-                # following may not work in SDXL, keep the line for future update
-                example["input_ids2"] = self.tokenizer[1](captions, padding=True, truncation=True, return_tensors="pt").input_ids
-            else:
-                example["input_ids2"] = None
-        else:
-            example["input_ids"] = torch.stack(input_ids_list)
-            example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
+        if len(text_encoder_outputs1_list) == 0:
+            if self.token_padding_disabled:
+                # padding=True means pad in the batch
+                example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
+                if len(self.tokenizers) > 1:
+                    example["input_ids2"] = self.tokenizer[1](
+                        captions, padding=True, truncation=True, return_tensors="pt"
+                    ).input_ids
+                else:
+                    example["input_ids2"] = None
+            else:
+                example["input_ids"] = torch.stack(input_ids_list)
+                example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
+
+            example["text_encoder_outputs1_list"] = None
+            example["text_encoder_outputs2_list"] = None
+            example["text_encoder_pool2_list"] = None
+        else:
+            example["input_ids"] = None
+            example["input_ids2"] = None
+            # # for assertion
+            # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
+            # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
+            example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
+            example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
+            example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
 
         if images[0] is not None:
             images = torch.stack(images)
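For context, a sketch of how a training step can branch on the collated batch above; the surrounding variable names (accelerator, weight_dtype, the tokenizer/encoder pairs, args) are illustrative assumptions, not part of this commit:

# assumed surrounding variables: accelerator, weight_dtype, tokenizer1/2, text_encoder1/2, args
if batch["text_encoder_outputs1_list"] is not None:
    # cached path: reuse the precomputed encoder outputs, skip both text encoders
    hidden1 = batch["text_encoder_outputs1_list"].to(accelerator.device, dtype=weight_dtype)
    hidden2 = batch["text_encoder_outputs2_list"].to(accelerator.device, dtype=weight_dtype)
    pool2 = batch["text_encoder_pool2_list"].to(accelerator.device, dtype=weight_dtype)
else:
    # uncached path: encode the tokenized captions on the fly
    hidden1, hidden2, pool2 = get_hidden_states_sdxl(
        args.max_token_length,
        batch["input_ids"].to(accelerator.device),
        batch["input_ids2"].to(accelerator.device),
        tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype,
    )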
@@ -1073,6 +1179,8 @@ class BaseDataset(torch.utils.data.Dataset):
     def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
         captions = []
         images = []
+        input_ids1_list = []
+        input_ids2_list = []
         absolute_paths = []
         resized_sizes = []
         bucket_reso = None
@@ -1092,14 +1200,24 @@ class BaseDataset(torch.utils.data.Dataset):
             assert random_crop == subset.random_crop, "random_crop must be same in a batch"
             assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
 
-            caption = image_info.caption  # TODO cache some patterns of droping, shuffling, etc.
-            image = load_image(image_info.absolute_path)
+            caption = image_info.caption  # TODO cache some patterns of dropping, shuffling, etc.
+
+            if self.caching_mode == "latents":
+                image = load_image(image_info.absolute_path)
+            else:
+                image = None
+
+            if self.caching_mode == "text":
+                input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
+                input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
+            else:
+                input_ids1 = None
+                input_ids2 = None
 
             captions.append(caption)
             images.append(image)
+            input_ids1_list.append(input_ids1)
+            input_ids2_list.append(input_ids2)
             absolute_paths.append(image_info.absolute_path)
             resized_sizes.append(image_info.resized_size)
@@ -1110,6 +1228,8 @@ class BaseDataset(torch.utils.data.Dataset):
         example["images"] = images
 
         example["captions"] = captions
+        example["input_ids1_list"] = input_ids1_list
+        example["input_ids2_list"] = input_ids2_list
         example["absolute_paths"] = absolute_paths
         example["resized_sizes"] = resized_sizes
         example["flip_aug"] = flip_aug
@@ -1680,6 +1800,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             print(f"[Dataset {i}]")
             dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
 
+    def cache_text_encoder_outputs(
+        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
+    ):
+        for i, dataset in enumerate(self.datasets):
+            print(f"[Dataset {i}]")
+            dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
+
     def set_caching_mode(self, caching_mode):
         for dataset in self.datasets:
             dataset.set_caching_mode(caching_mode)
@@ -1982,6 +2109,7 @@ def cache_batch_latents(
     images = []
     for info in image_infos:
         image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
+        # TODO the image metadata may be broken, so the bucket assigned from metadata can mismatch the actual image size; a check should be added
        image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
         image = IMAGE_TRANSFORMS(image)
         images.append(image)
@@ -2015,6 +2143,55 @@ def cache_batch_latents(
             info.latents_flipped = flipped_latent
 
 
+def cache_batch_text_encoder_outputs(
+    image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
+):
+    input_ids1 = input_ids1.to(text_encoders[0].device)
+    input_ids2 = input_ids2.to(text_encoders[1].device)
+
+    with torch.no_grad():
+        b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
+            max_token_length,
+            input_ids1,
+            input_ids2,
+            tokenizers[0],
+            tokenizers[1],
+            text_encoders[0],
+            text_encoders[1],
+            dtype,
+        )
+
+        # move to CPU here, otherwise the tensors get overwritten later
+        b_hidden_state1 = b_hidden_state1.detach().to("cpu")  # b,n*75+2,768
+        b_hidden_state2 = b_hidden_state2.detach().to("cpu")  # b,n*75+2,1280
+        b_pool2 = b_pool2.detach().to("cpu")  # b,1280
+
+    for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
+        if cache_to_disk:
+            save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden_state1, hidden_state2, pool2)
+        else:
+            info.text_encoder_outputs1 = hidden_state1
+            info.text_encoder_outputs2 = hidden_state2
+            info.text_encoder_pool2 = pool2
+
+
+def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
+    np.savez(
+        npz_path,
+        hidden_state1=hidden_state1.cpu().float().numpy(),
+        hidden_state2=hidden_state2.cpu().float().numpy(),
+        pool2=pool2.cpu().float().numpy(),
+    )
+
+
+def load_text_encoder_outputs_from_disk(npz_path):
+    with np.load(npz_path) as f:
+        hidden_state1 = torch.from_numpy(f["hidden_state1"])
+        hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
+        pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
+    return hidden_state1, hidden_state2, pool2
+
+
 # endregion
 
 # region module replacement
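A quick round-trip check of the two npz helpers above, with dummy tensors in the cached shapes for an n=1 (75+2 token) caption; the path is hypothetical:

import numpy as np
import torch

hidden_state1 = torch.randn(77, 768)   # n*75+2 tokens, encoder-1 width
hidden_state2 = torch.randn(77, 1280)  # n*75+2 tokens, encoder-2 width
pool2 = torch.randn(1280)

npz_path = "/tmp/0001_te_outputs.npz"  # hypothetical path
save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2)
h1, h2, p2 = load_text_encoder_outputs_from_disk(npz_path)
assert h1.shape == (77, 768) and h2.shape == (77, 1280) and p2.shape == (1280,)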
@@ -3501,6 +3678,62 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
     return encoder_hidden_states
 
 
+def get_hidden_states_sdxl(
+    max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
+):
+    # input_ids: b,n,77 -> b*n, 77
+    b_size = input_ids1.size()[0]
+    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
+    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
+
+    # text_encoder1
+    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
+    hidden_states1 = enc_out["hidden_states"][11]
+
+    # text_encoder2
+    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
+    hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
+    pool2 = enc_out["text_embeds"]
+
+    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
+    n_size = 1 if max_token_length is None else max_token_length // 75
+    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
+    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
+
+    if max_token_length is not None:
+        # bs*3, 77, 768 or 1024
+        # encoder1: fold the repeated <BOS>...<EOS> chunks back into a single <BOS>...<EOS>
+        states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, max_token_length, tokenizer1.model_max_length):
+            states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
+        states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
+        hidden_states1 = torch.cat(states_list, dim=1)
+
+        # v2: fold the repeated <BOS>...<EOS> <PAD> ... chunks back into <BOS>...<EOS> <PAD> ...; honestly not sure this implementation is right
+        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, max_token_length, tokenizer2.model_max_length):
+            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
+            # this causes an error:
+            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+            # if i > 1:
+            #     for j in range(len(chunk)):  # batch_size
+            #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # empty, i.e. the <BOS> <EOS> <PAD> ... pattern
+            #             chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
+            states_list.append(chunk)  # from after <BOS> to before <EOS>
+        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
+        hidden_states2 = torch.cat(states_list, dim=1)
+
+        # for the pool, use the first of the n chunks
+        pool2 = pool2[::n_size]
+
+    if weight_dtype is not None:
+        # this is required for additional network training
+        hidden_states1 = hidden_states1.to(weight_dtype)
+        hidden_states2 = hidden_states2.to(weight_dtype)
+
+    return hidden_states1, hidden_states2, pool2
+
+
 def default_if_none(value, default):
     return default if value is None else value
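To make the chunk folding concrete, a small self-contained shape check for max_token_length=225 (n = 225 // 75 = 3 chunks of 77 tokens); this mirrors the encoder1 branch above with random data, not the commit's code:

import torch

b, width = 2, 768
max_token_length, chunk_len = 225, 77          # chunk_len = tokenizer.model_max_length
hidden = torch.randn(b, 3 * chunk_len, width)  # (b, 231, 768) after the reshape step

states = [hidden[:, 0].unsqueeze(1)]           # <BOS>
for i in range(1, max_token_length, chunk_len):
    states.append(hidden[:, i : i + chunk_len - 2])  # 75 payload tokens per chunk
states.append(hidden[:, -1].unsqueeze(1))      # <EOS>
folded = torch.cat(states, dim=1)
print(folded.shape)  # torch.Size([2, 227, 768]) -> 1 + 3*75 + 1 tokens, i.e. n*75+2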