mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add caching to disk for text encoder outputs
This commit is contained in:
@@ -140,62 +140,6 @@ def load_tokenizers(args: argparse.Namespace):
|
|||||||
return tokeniers
|
return tokeniers
|
||||||
|
|
||||||
|
|
||||||
def get_hidden_states(
    args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
):
    """Run both SDXL text encoders on chunked token ids and rejoin the chunks.

    The token ids arrive as (batch, n, model_max_length) and are flattened to
    (batch * n, model_max_length) before encoding. Hidden state index 11 is
    taken from the first encoder and the penultimate hidden state plus
    ``text_embeds`` from the second.

    When ``args.max_token_length`` is set, the per-chunk states are stitched
    back into a single sequence per caption (leading position, the inner
    tokens of every chunk, trailing position) and only the pooled output of
    the first chunk of each caption is kept.

    Returns:
        (hidden_states1, hidden_states2, pool2). The two hidden states are
        cast to ``weight_dtype`` when it is given; ``pool2`` is not cast.
    """
    batch_count = input_ids1.size()[0]
    window1 = tokenizer1.model_max_length
    window2 = tokenizer2.model_max_length

    # (batch, n, window) -> (batch * n, window)
    flat_ids1 = input_ids1.reshape((-1, window1))
    flat_ids2 = input_ids2.reshape((-1, window2))

    # first text encoder
    encoded1 = text_encoder1(flat_ids1, output_hidden_states=True, return_dict=True)
    hidden_states1 = encoded1["hidden_states"][11]

    # second text encoder
    encoded2 = text_encoder2(flat_ids2, output_hidden_states=True, return_dict=True)
    hidden_states2 = encoded2["hidden_states"][-2]  # penultimate layer
    pool2 = encoded2["text_embeds"]

    # (batch * n, window, dim) -> (batch, n * window, dim)
    hidden_states1 = hidden_states1.reshape((batch_count, -1, hidden_states1.shape[-1]))
    hidden_states2 = hidden_states2.reshape((batch_count, -1, hidden_states2.shape[-1]))

    max_tokens = args.max_token_length
    if max_tokens is not None:
        chunks_per_caption = max_tokens // 75

        # encoder 1: collapse the repeated <BOS>...<EOS> chunks into one <BOS>...<EOS>
        inner1 = [hidden_states1[:, start : start + window1 - 2] for start in range(1, max_tokens, window1)]
        hidden_states1 = torch.cat(
            [hidden_states1[:, 0].unsqueeze(1)] + inner1 + [hidden_states1[:, -1].unsqueeze(1)],
            dim=1,
        )

        # encoder 2: collapse the repeated <BOS>...<EOS> <PAD>... chunks the same way.
        # The original author notes uncertainty about this handling: an earlier attempt
        # copied the following <PAD> state over the first position of empty chunks, but
        # that in-place write raised "one of the variables needed for gradient
        # computation has been modified by an inplace operation", so chunks are used as-is.
        inner2 = [hidden_states2[:, start : start + window2 - 2] for start in range(1, max_tokens, window2)]
        hidden_states2 = torch.cat(
            [hidden_states2[:, 0].unsqueeze(1)] + inner2 + [hidden_states2[:, -1].unsqueeze(1)],  # last is <EOS> or <PAD>
            dim=1,
        )

        # the pooled output of the first chunk represents the whole caption
        pool2 = pool2[::chunks_per_caption]

    if weight_dtype is not None:
        # this is required for additional network training
        hidden_states1 = hidden_states1.to(weight_dtype)
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states1, hidden_states2, pool2
|
|
||||||
|
|
||||||
|
|
||||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||||
"""
|
"""
|
||||||
Create sinusoidal timestep embeddings.
|
Create sinusoidal timestep embeddings.
|
||||||
@@ -391,6 +335,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_text_encoder_outputs_to_disk",
|
||||||
|
action="store_true",
|
||||||
|
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def verify_sdxl_training_args(args: argparse.Namespace):
|
def verify_sdxl_training_args(args: argparse.Namespace):
|
||||||
@@ -417,6 +366,13 @@ def verify_sdxl_training_args(args: argparse.Namespace):
|
|||||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||||
|
|
||||||
|
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||||
|
args.cache_text_encoder_outputs = True
|
||||||
|
print(
|
||||||
|
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
||||||
|
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sample_images(*args, **kwargs):
    """Call ``train_util.sample_images_common`` with the SDXL pipeline class.

    ``SdxlStableDiffusionLongPromptWeightingPipeline`` is passed as the first
    argument; every positional and keyword argument is forwarded unchanged and
    the callee's return value is returned as-is.
    """
    pipeline_class = SdxlStableDiffusionLongPromptWeightingPipeline
    return train_util.sample_images_common(pipeline_class, *args, **kwargs)
|
||||||
|
|||||||
@@ -104,6 +104,8 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||||
|
|
||||||
|
|
||||||
class ImageInfo:
|
class ImageInfo:
|
||||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||||
@@ -122,6 +124,11 @@ class ImageInfo:
|
|||||||
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
|
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
|
||||||
self.cond_img_path: str = None
|
self.cond_img_path: str = None
|
||||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||||
|
# SDXL, optional
|
||||||
|
self.text_encoder_outputs_npz: Optional[str] = None
|
||||||
|
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||||
|
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||||
|
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class BucketManager:
|
class BucketManager:
|
||||||
@@ -793,7 +800,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||||
# ちょっと速くした
|
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||||
print("caching latents.")
|
print("caching latents.")
|
||||||
|
|
||||||
image_infos = list(self.image_data.values())
|
image_infos = list(self.image_data.values())
|
||||||
@@ -841,9 +848,73 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
|
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
|
||||||
|
print("caching latents...")
|
||||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
|
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
|
||||||
|
|
||||||
|
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
||||||
|
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
|
||||||
|
# SD1/2に対応するにはv2のフラグを持つ必要があるので後回し
|
||||||
|
def cache_text_encoder_outputs(
    self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
):
    """Precompute text encoder outputs for the entries of ``self.image_data``.

    Exactly two tokenizers are required (SDXL only). With ``cache_to_disk``
    each entry gets ``text_encoder_outputs_npz`` set to the image path with its
    extension replaced by ``TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX``; entries whose
    file already exists are not re-encoded, and non-main processes only record
    the path and return before touching the encoders. Without
    ``cache_to_disk`` every entry is encoded.

    The encoders are moved to ``device`` (and to ``weight_dtype`` when it is
    not None), captions are tokenized with ``self.get_input_ids`` and the
    batches of ``self.batch_size`` entries are handed to
    ``cache_batch_text_encoder_outputs``.

    NOTE(review): the original comment says multi-GPU is not supported here and
    points to tools/cache_latents.py; confirm which tool applies to text
    encoder outputs.
    """
    assert len(tokenizers) == 2, "only support SDXL"

    print("caching text encoder outputs.")
    all_infos = list(self.image_data.values())

    print("checking cache existence...")
    pending = []
    for info in tqdm(all_infos):
        if cache_to_disk:
            npz_path = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
            info.text_encoder_outputs_npz = npz_path

            # non-main processes only record the path; existing cache files are reused
            if not is_main_process or os.path.exists(npz_path):
                continue

        pending.append(info)

    # when caching to disk, only the main process runs the encoders
    if cache_to_disk and not is_main_process:
        return

    # put the text encoders on the target device / dtype
    for encoder in text_encoders:
        encoder.to(device)
        if weight_dtype is not None:
            encoder.to(dtype=weight_dtype)

    # tokenize every pending caption with both tokenizers, then split into batches
    tokenized = [
        (info, self.get_input_ids(info.caption, tokenizers[0]), self.get_input_ids(info.caption, tokenizers[1]))
        for info in pending
    ]
    # a batch is closed as soon as it holds batch_size items, i.e. never fewer than one
    step = max(1, self.batch_size)
    batches = [tokenized[start : start + step] for start in range(0, len(tokenized), step)]

    # run the text encoders batch by batch; results go to memory or disk
    print("caching text encoder outputs...")
    for group in tqdm(batches):
        infos, ids1, ids2 = zip(*group)
        stacked_ids1 = torch.stack(ids1, dim=0)
        stacked_ids2 = torch.stack(ids2, dim=0)
        cache_batch_text_encoder_outputs(
            infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, stacked_ids1, stacked_ids2, weight_dtype
        )
|
||||||
|
|
||||||
def get_image_size(self, image_path):
    """Return the ``size`` attribute of the image stored at ``image_path``.

    The image is opened inside a ``with`` block so the underlying file is
    closed as soon as the size has been read, instead of staying open until
    the object is garbage collected.
    """
    with Image.open(image_path) as image:
        return image.size
|
||||||
@@ -931,6 +1002,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
crop_top_lefts = []
|
crop_top_lefts = []
|
||||||
target_sizes_hw = []
|
target_sizes_hw = []
|
||||||
flippeds = [] # 変数名が微妙
|
flippeds = [] # 変数名が微妙
|
||||||
|
text_encoder_outputs1_list = []
|
||||||
|
text_encoder_outputs2_list = []
|
||||||
|
text_encoder_pool2_list = []
|
||||||
|
|
||||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||||
image_info = self.image_data[image_key]
|
image_info = self.image_data[image_key]
|
||||||
@@ -1012,44 +1086,76 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
target_sizes_hw.append((target_size[1], target_size[0]))
|
target_sizes_hw.append((target_size[1], target_size[0]))
|
||||||
flippeds.append(flipped)
|
flippeds.append(flipped)
|
||||||
|
|
||||||
caption = self.process_caption(subset, image_info.caption)
|
# captionとtext encoder outputを処理する
|
||||||
if self.XTI_layers:
|
caption = image_info.caption # default
|
||||||
caption_layer = []
|
if image_info.text_encoder_outputs1 is not None:
|
||||||
for layer in self.XTI_layers:
|
text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
|
||||||
token_strings_from = " ".join(self.token_strings)
|
text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
|
||||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
text_encoder_pool2_list.append(image_info.text_encoder_pool2)
|
||||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
|
||||||
caption_layer.append(caption_)
|
|
||||||
captions.append(caption_layer)
|
|
||||||
else:
|
|
||||||
captions.append(caption)
|
captions.append(caption)
|
||||||
|
elif image_info.text_encoder_outputs_npz is not None:
|
||||||
if not self.token_padding_disabled: # this option might be omitted in future
|
text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
|
||||||
|
image_info.text_encoder_outputs_npz
|
||||||
|
)
|
||||||
|
text_encoder_outputs1_list.append(text_encoder_outputs1)
|
||||||
|
text_encoder_outputs2_list.append(text_encoder_outputs2)
|
||||||
|
text_encoder_pool2_list.append(text_encoder_pool2)
|
||||||
|
captions.append(caption)
|
||||||
|
else:
|
||||||
|
caption = self.process_caption(subset, image_info.caption)
|
||||||
if self.XTI_layers:
|
if self.XTI_layers:
|
||||||
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
caption_layer = []
|
||||||
|
for layer in self.XTI_layers:
|
||||||
|
token_strings_from = " ".join(self.token_strings)
|
||||||
|
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||||
|
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||||
|
caption_layer.append(caption_)
|
||||||
|
captions.append(caption_layer)
|
||||||
else:
|
else:
|
||||||
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
captions.append(caption)
|
||||||
input_ids_list.append(token_caption)
|
|
||||||
|
|
||||||
if len(self.tokenizers) > 1:
|
if not self.token_padding_disabled: # this option might be omitted in future
|
||||||
if self.XTI_layers:
|
if self.XTI_layers:
|
||||||
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||||
else:
|
else:
|
||||||
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||||
input_ids2_list.append(token_caption2)
|
input_ids_list.append(token_caption)
|
||||||
|
|
||||||
|
if len(self.tokenizers) > 1:
|
||||||
|
if self.XTI_layers:
|
||||||
|
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||||
|
else:
|
||||||
|
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||||
|
input_ids2_list.append(token_caption2)
|
||||||
|
|
||||||
example = {}
|
example = {}
|
||||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||||
|
|
||||||
if self.token_padding_disabled:
|
if len(text_encoder_outputs1_list) == 0:
|
||||||
# padding=True means pad in the batch
|
if self.token_padding_disabled:
|
||||||
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
# padding=True means pad in the batch
|
||||||
if len(self.tokenizers) > 1:
|
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||||
# following may not work in SDXL, keep the line for future update
|
if len(self.tokenizers) > 1:
|
||||||
example["input_ids2"] = self.tokenizer[1](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
example["input_ids2"] = self.tokenizer[1](
|
||||||
|
captions, padding=True, truncation=True, return_tensors="pt"
|
||||||
|
).input_ids
|
||||||
|
else:
|
||||||
|
example["input_ids2"] = None
|
||||||
|
else:
|
||||||
|
example["input_ids"] = torch.stack(input_ids_list)
|
||||||
|
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
|
||||||
|
example["text_encoder_outputs1_list"] = None
|
||||||
|
example["text_encoder_outputs2_list"] = None
|
||||||
|
example["text_encoder_pool2_list"] = None
|
||||||
else:
|
else:
|
||||||
example["input_ids"] = torch.stack(input_ids_list)
|
example["input_ids"] = None
|
||||||
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
|
example["input_ids2"] = None
|
||||||
|
# # for assertion
|
||||||
|
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
|
||||||
|
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
|
||||||
|
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
|
||||||
|
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
|
||||||
|
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
|
||||||
|
|
||||||
if images[0] is not None:
|
if images[0] is not None:
|
||||||
images = torch.stack(images)
|
images = torch.stack(images)
|
||||||
@@ -1073,6 +1179,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
|
def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
|
||||||
captions = []
|
captions = []
|
||||||
images = []
|
images = []
|
||||||
|
input_ids1_list = []
|
||||||
|
input_ids2_list = []
|
||||||
absolute_paths = []
|
absolute_paths = []
|
||||||
resized_sizes = []
|
resized_sizes = []
|
||||||
bucket_reso = None
|
bucket_reso = None
|
||||||
@@ -1092,14 +1200,24 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
|
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
|
||||||
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
|
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
|
||||||
|
|
||||||
caption = image_info.caption # TODO cache some patterns of droping, shuffling, etc.
|
caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc.
|
||||||
|
|
||||||
if self.caching_mode == "latents":
|
if self.caching_mode == "latents":
|
||||||
image = load_image(image_info.absolute_path)
|
image = load_image(image_info.absolute_path)
|
||||||
else:
|
else:
|
||||||
image = None
|
image = None
|
||||||
|
|
||||||
|
if self.caching_mode == "text":
|
||||||
|
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
|
||||||
|
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||||
|
else:
|
||||||
|
input_ids1 = None
|
||||||
|
input_ids2 = None
|
||||||
|
|
||||||
captions.append(caption)
|
captions.append(caption)
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
input_ids1_list.append(input_ids1)
|
||||||
|
input_ids2_list.append(input_ids2)
|
||||||
absolute_paths.append(image_info.absolute_path)
|
absolute_paths.append(image_info.absolute_path)
|
||||||
resized_sizes.append(image_info.resized_size)
|
resized_sizes.append(image_info.resized_size)
|
||||||
|
|
||||||
@@ -1110,6 +1228,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
example["images"] = images
|
example["images"] = images
|
||||||
|
|
||||||
example["captions"] = captions
|
example["captions"] = captions
|
||||||
|
example["input_ids1_list"] = input_ids1_list
|
||||||
|
example["input_ids2_list"] = input_ids2_list
|
||||||
example["absolute_paths"] = absolute_paths
|
example["absolute_paths"] = absolute_paths
|
||||||
example["resized_sizes"] = resized_sizes
|
example["resized_sizes"] = resized_sizes
|
||||||
example["flip_aug"] = flip_aug
|
example["flip_aug"] = flip_aug
|
||||||
@@ -1680,6 +1800,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
print(f"[Dataset {i}]")
|
print(f"[Dataset {i}]")
|
||||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||||
|
|
||||||
|
def cache_text_encoder_outputs(
    self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
):
    """Call ``cache_text_encoder_outputs`` on every dataset in ``self.datasets``.

    A ``[Dataset i]`` header is printed before each dataset, and all arguments
    are forwarded positionally and unchanged.
    """
    forwarded = (tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
    for index, member in enumerate(self.datasets):
        print(f"[Dataset {index}]")
        member.cache_text_encoder_outputs(*forwarded)
|
||||||
|
|
||||||
def set_caching_mode(self, caching_mode):
|
def set_caching_mode(self, caching_mode):
|
||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
dataset.set_caching_mode(caching_mode)
|
dataset.set_caching_mode(caching_mode)
|
||||||
@@ -1982,6 +2109,7 @@ def cache_batch_latents(
|
|||||||
images = []
|
images = []
|
||||||
for info in image_infos:
|
for info in image_infos:
|
||||||
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
|
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
|
||||||
|
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||||
image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||||
image = IMAGE_TRANSFORMS(image)
|
image = IMAGE_TRANSFORMS(image)
|
||||||
images.append(image)
|
images.append(image)
|
||||||
@@ -2015,6 +2143,55 @@ def cache_batch_latents(
|
|||||||
info.latents_flipped = flipped_latent
|
info.latents_flipped = flipped_latent
|
||||||
|
|
||||||
|
|
||||||
|
def cache_batch_text_encoder_outputs(
    image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
    """Encode one batch of captions and store the result for each image info.

    The token ids are moved to the device of their respective text encoder and
    encoded with ``get_hidden_states_sdxl`` under ``torch.no_grad()``. The
    three outputs are detached and moved to the CPU, then split along the
    first dimension and paired with ``image_infos``.

    With ``cache_to_disk`` each triple is written to
    ``info.text_encoder_outputs_npz``; otherwise it is stored on the info as
    ``text_encoder_outputs1``, ``text_encoder_outputs2`` and
    ``text_encoder_pool2``.
    """
    ids1_on_device = input_ids1.to(text_encoders[0].device)
    ids2_on_device = input_ids2.to(text_encoders[1].device)

    with torch.no_grad():
        encoded = get_hidden_states_sdxl(
            max_token_length,
            ids1_on_device,
            ids2_on_device,
            tokenizers[0],
            tokenizers[1],
            text_encoders[0],
            text_encoders[1],
            dtype,
        )

        # the original note says the results get overwritten unless they are moved to the CPU here
        # per the original shape notes: (b, n*75+2, 768), (b, n*75+2, 1280), (b, 1280)
        states1_cpu, states2_cpu, pooled_cpu = (tensor.detach().to("cpu") for tensor in encoded)

    for info, state1, state2, pooled in zip(image_infos, states1_cpu, states2_cpu, pooled_cpu):
        if not cache_to_disk:
            info.text_encoder_outputs1 = state1
            info.text_encoder_outputs2 = state2
            info.text_encoder_pool2 = pooled
            continue
        save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, state1, state2, pooled)
|
||||||
|
|
||||||
|
|
||||||
|
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
    """Write text encoder outputs to ``npz_path`` with ``np.savez``.

    Each tensor is detached, moved to the CPU and converted to float32 before
    being stored under the keys ``hidden_state1``, ``hidden_state2`` and
    ``pool2``.

    ``hidden_state2`` and ``pool2`` may be None, in which case their key is
    left out of the archive; ``load_text_encoder_outputs_from_disk`` returns
    None for a missing key, so the pair round-trips. Detaching first lets
    tensors that still require grad be saved instead of raising in
    ``.numpy()``.
    """

    def _as_float_array(tensor):
        return tensor.detach().cpu().float().numpy()

    arrays = {"hidden_state1": _as_float_array(hidden_state1)}
    if hidden_state2 is not None:
        arrays["hidden_state2"] = _as_float_array(hidden_state2)
    if pool2 is not None:
        arrays["pool2"] = _as_float_array(pool2)

    np.savez(npz_path, **arrays)
|
||||||
|
|
||||||
|
|
||||||
|
def load_text_encoder_outputs_from_disk(npz_path):
    """Read text encoder outputs from an ``.npz`` archive as torch tensors.

    ``hidden_state1`` must be present in the archive. ``hidden_state2`` and
    ``pool2`` are optional and come back as None when their key is missing.

    Returns:
        (hidden_state1, hidden_state2, pool2)
    """

    def _optional_tensor(archive, key):
        return torch.from_numpy(archive[key]) if key in archive else None

    with np.load(npz_path) as archive:
        first_states = torch.from_numpy(archive["hidden_state1"])
        second_states = _optional_tensor(archive, "hidden_state2")
        pooled = _optional_tensor(archive, "pool2")
    return first_states, second_states, pooled
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region モジュール入れ替え部
|
# region モジュール入れ替え部
|
||||||
@@ -3501,6 +3678,62 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
|||||||
return encoder_hidden_states
|
return encoder_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_chunked_hidden_states(hidden_states, max_token_length, chunk_length):
    """Stitch per-chunk encoder states back into one sequence per caption.

    ``hidden_states`` holds the chunks of a caption back to back along
    dimension 1. The result is the first position, then
    ``chunk_length - 2`` positions from each chunk (skipping each chunk's first
    and last position), then the very last position.
    """
    pieces = [hidden_states[:, 0].unsqueeze(1)]  # leading <BOS>
    for start in range(1, max_token_length, chunk_length):
        # from just after the chunk's first position up to just before its last
        pieces.append(hidden_states[:, start : start + chunk_length - 2])
    pieces.append(hidden_states[:, -1].unsqueeze(1))  # trailing <EOS> (or <PAD> for encoder 2)
    return torch.cat(pieces, dim=1)


def get_hidden_states_sdxl(
    max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
):
    """Run both SDXL text encoders on chunked token ids and rejoin the chunks.

    Args:
        max_token_length: None for single-chunk captions; otherwise the
            extended caption length, which selects how chunks are rejoined and
            how many chunks (``max_token_length // 75``) belong to one caption.
        input_ids1, input_ids2: token ids shaped (batch, n, model_max_length)
            for the respective tokenizer; flattened to
            (batch * n, model_max_length) before encoding.
        tokenizer1, tokenizer2: only ``model_max_length`` is read.
        text_encoder1, text_encoder2: called with ``output_hidden_states=True``
            and ``return_dict=True``. Hidden state index 11 is taken from the
            first; the penultimate hidden state and ``text_embeds`` from the
            second.
        weight_dtype: when given, both hidden states are cast to it
            (required for additional network training). ``pool2`` is not cast.

    Returns:
        (hidden_states1, hidden_states2, pool2)
    """
    # input_ids: (b, n, 77) -> (b * n, 77)
    b_size = input_ids1.size()[0]
    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))
    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))

    # text_encoder1
    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
    hidden_states1 = enc_out["hidden_states"][11]

    # text_encoder2
    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
    hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
    pool2 = enc_out["text_embeds"]

    # (b * n, 77, dim) -> (b, n * 77, dim)
    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))

    if max_token_length is not None:
        n_size = max_token_length // 75

        # encoder 1: collapse the repeated <BOS>...<EOS> chunks into one <BOS>...<EOS>
        hidden_states1 = _merge_chunked_hidden_states(hidden_states1, max_token_length, tokenizer1.model_max_length)

        # encoder 2: collapse the repeated <BOS>...<EOS> <PAD>... chunks the same way.
        # The original author notes uncertainty about this handling. Patching empty
        # chunks in place (copying the next <PAD> state over the first position) was
        # tried and raised "one of the variables needed for gradient computation has
        # been modified by an inplace operation", so chunks are used unmodified.
        hidden_states2 = _merge_chunked_hidden_states(hidden_states2, max_token_length, tokenizer2.model_max_length)

        # the pooled output of the first chunk represents the whole caption
        pool2 = pool2[::n_size]

    if weight_dtype is not None:
        # this is required for additional network training
        hidden_states1 = hidden_states1.to(weight_dtype)
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states1, hidden_states2, pool2
|
||||||
|
|
||||||
|
|
||||||
def default_if_none(value, default):
    """Return ``value`` unless it is None, in which case return ``default``.

    Only None triggers the fallback; other falsy values such as 0 or ""
    are returned as they are.
    """
    if value is None:
        return default
    return value
|
||||||
|
|
||||||
|
|||||||
@@ -204,10 +204,6 @@ def train(args):
|
|||||||
text_encoder2.gradient_checkpointing_enable()
|
text_encoder2.gradient_checkpointing_enable()
|
||||||
training_models.append(text_encoder1)
|
training_models.append(text_encoder1)
|
||||||
training_models.append(text_encoder2)
|
training_models.append(text_encoder2)
|
||||||
|
|
||||||
text_encoder1_cache = None
|
|
||||||
text_encoder2_cache = None
|
|
||||||
|
|
||||||
# set require_grad=True later
|
# set require_grad=True later
|
||||||
else:
|
else:
|
||||||
text_encoder1.requires_grad_(False)
|
text_encoder1.requires_grad_(False)
|
||||||
@@ -218,9 +214,15 @@ def train(args):
|
|||||||
# TextEncoderの出力をキャッシュする
|
# TextEncoderの出力をキャッシュする
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
# Text Encodes are eval and no grad
|
# Text Encodes are eval and no grad
|
||||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
with torch.no_grad():
|
||||||
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
|
train_dataset_group.cache_text_encoder_outputs(
|
||||||
)
|
(tokenizer1, tokenizer2),
|
||||||
|
(text_encoder1, text_encoder2),
|
||||||
|
accelerator.device,
|
||||||
|
None,
|
||||||
|
args.cache_text_encoder_outputs_to_disk,
|
||||||
|
accelerator.is_main_process,
|
||||||
|
)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
@@ -375,11 +377,10 @@ def train(args):
|
|||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
b_size = latents.shape[0]
|
|
||||||
|
|
||||||
input_ids1 = batch["input_ids"]
|
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||||
input_ids2 = batch["input_ids2"]
|
input_ids1 = batch["input_ids"]
|
||||||
if not args.cache_text_encoder_outputs:
|
input_ids2 = batch["input_ids2"]
|
||||||
with torch.set_grad_enabled(args.train_text_encoder):
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
# TODO support weighted captions
|
# TODO support weighted captions
|
||||||
@@ -395,8 +396,8 @@ def train(args):
|
|||||||
# else:
|
# else:
|
||||||
input_ids1 = input_ids1.to(accelerator.device)
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
input_ids2 = input_ids2.to(accelerator.device)
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||||
args,
|
args.max_token_length,
|
||||||
input_ids1,
|
input_ids1,
|
||||||
input_ids2,
|
input_ids2,
|
||||||
tokenizer1,
|
tokenizer1,
|
||||||
@@ -406,19 +407,26 @@ def train(args):
|
|||||||
None if not args.full_fp16 else weight_dtype,
|
None if not args.full_fp16 else weight_dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states1 = []
|
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||||
encoder_hidden_states2 = []
|
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
||||||
pool2 = []
|
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||||
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
|
||||||
input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
|
# # verify that the text encoder outputs are correct
|
||||||
input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
|
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
||||||
encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
|
# args.max_token_length,
|
||||||
hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
|
# batch["input_ids"].to(text_encoder1.device),
|
||||||
encoder_hidden_states2.append(hidden_states2)
|
# batch["input_ids2"].to(text_encoder1.device),
|
||||||
pool2.append(p2)
|
# tokenizer1,
|
||||||
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
# tokenizer2,
|
||||||
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
# text_encoder1,
|
||||||
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
# text_encoder2,
|
||||||
|
# None if not args.full_fp16 else weight_dtype,
|
||||||
|
# )
|
||||||
|
# b_size = encoder_hidden_states1.shape[0]
|
||||||
|
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
|
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
|
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
|
# print("text encoder outputs verified")
|
||||||
|
|
||||||
# get size embeddings
|
# get size embeddings
|
||||||
orig_size = batch["original_sizes_hw"]
|
orig_size = batch["original_sizes_hw"]
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
return args.cache_text_encoder_outputs
|
return args.cache_text_encoder_outputs
|
||||||
|
|
||||||
def cache_text_encoder_outputs_if_needed(
|
def cache_text_encoder_outputs_if_needed(
|
||||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype
|
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||||
):
|
):
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
@@ -60,34 +60,33 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
dataset.cache_text_encoder_outputs(
|
||||||
args, accelerator, tokenizers, text_encoders, dataset, weight_dtype
|
tokenizers,
|
||||||
|
text_encoders,
|
||||||
|
accelerator.device,
|
||||||
|
weight_dtype,
|
||||||
|
args.cache_text_encoder_outputs_to_disk,
|
||||||
|
accelerator.is_main_process,
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
self.text_encoder1_cache = text_encoder1_cache
|
|
||||||
self.text_encoder2_cache = text_encoder2_cache
|
|
||||||
|
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
print("move vae and unet back to original device")
|
print("move vae and unet back to original device")
|
||||||
vae.to(org_vae_device)
|
vae.to(org_vae_device)
|
||||||
unet.to(org_unet_device)
|
unet.to(org_unet_device)
|
||||||
else:
|
else:
|
||||||
self.text_encoder1_cache = None
|
|
||||||
self.text_encoder2_cache = None
|
|
||||||
|
|
||||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||||
text_encoders[0].to(accelerator.device)
|
text_encoders[0].to(accelerator.device)
|
||||||
text_encoders[1].to(accelerator.device)
|
text_encoders[1].to(accelerator.device)
|
||||||
|
|
||||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||||
input_ids1 = batch["input_ids"]
|
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||||
input_ids2 = batch["input_ids2"]
|
input_ids1 = batch["input_ids"]
|
||||||
if not args.cache_text_encoder_outputs:
|
input_ids2 = batch["input_ids2"]
|
||||||
with torch.enable_grad():
|
with torch.enable_grad():
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
# TODO support weighted captions
|
# TODO support weighted captions
|
||||||
@@ -103,8 +102,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
# else:
|
# else:
|
||||||
input_ids1 = input_ids1.to(accelerator.device)
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
input_ids2 = input_ids2.to(accelerator.device)
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||||
args,
|
args.max_token_length,
|
||||||
input_ids1,
|
input_ids1,
|
||||||
input_ids2,
|
input_ids2,
|
||||||
tokenizers[0],
|
tokenizers[0],
|
||||||
@@ -114,19 +113,27 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
None if not args.full_fp16 else weight_dtype,
|
None if not args.full_fp16 else weight_dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states1 = []
|
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||||
encoder_hidden_states2 = []
|
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
||||||
pool2 = []
|
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||||
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
|
||||||
input_id1_cache_key = tuple(input_id1.flatten().tolist())
|
# # verify that the text encoder outputs are correct
|
||||||
input_id2_cache_key = tuple(input_id2.flatten().tolist())
|
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
||||||
encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
|
# args.max_token_length,
|
||||||
hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
|
# batch["input_ids"].to(text_encoders[0].device),
|
||||||
encoder_hidden_states2.append(hidden_states2)
|
# batch["input_ids2"].to(text_encoders[0].device),
|
||||||
pool2.append(p2)
|
# tokenizers[0],
|
||||||
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
# tokenizers[1],
|
||||||
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
# text_encoders[0],
|
||||||
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
# text_encoders[1],
|
||||||
|
# None if not args.full_fp16 else weight_dtype,
|
||||||
|
# )
|
||||||
|
# b_size = encoder_hidden_states1.shape[0]
|
||||||
|
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
|
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
|
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
|
# print("text encoder outputs verified")
|
||||||
|
|
||||||
|
|
||||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||||
|
|
||||||
|
|||||||
@@ -45,8 +45,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
|||||||
with torch.enable_grad():
|
with torch.enable_grad():
|
||||||
input_ids1 = input_ids1.to(accelerator.device)
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
input_ids2 = input_ids2.to(accelerator.device)
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||||
args,
|
args.max_token_length,
|
||||||
input_ids1,
|
input_ids1,
|
||||||
input_ids2,
|
input_ids2,
|
||||||
tokenizers[0],
|
tokenizers[0],
|
||||||
|
|||||||
191
tools/cache_text_encoder_outputs.py
Normal file
191
tools/cache_text_encoder_outputs.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
from multiprocessing import Value
|
||||||
|
import os
|
||||||
|
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from library import config_util
|
||||||
|
from library import train_util
|
||||||
|
from library import sdxl_train_util
|
||||||
|
from library.config_util import (
|
||||||
|
ConfigSanitizer,
|
||||||
|
BlueprintGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_to_disk(args: argparse.Namespace) -> None:
    """Pre-compute text encoder outputs for a dataset and write them to disk.

    Builds the dataset group described by ``args``, loads the text encoder(s),
    iterates the dataset in "text" caching mode and hands every batch to
    ``train_util.cache_batch_text_encoder_outputs`` so that one cache file is
    written next to each image (image path without extension plus
    ``train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX``).

    Args:
        args: Parsed command line namespace (see ``setup_parser``).
            ``cache_text_encoder_outputs_to_disk`` and ``sdxl`` must be truthy.

    Raises:
        AssertionError: If ``cache_text_encoder_outputs_to_disk`` or ``sdxl``
            is not set.
    """
    train_util.prepare_dataset_args(args, True)

    # check cache arg
    assert (
        args.cache_text_encoder_outputs_to_disk
    ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"

    # The non-SDXL branches below are prepared as far as possible, but only SDXL works for now.
    assert (
        args.sdxl
    ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"

    # No metadata JSON given -> build the dataset from directories (DreamBooth-style).
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number sequence

    # Prepare the tokenizers: they are required to run the dataset.
    if args.sdxl:
        tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
        tokenizers = [tokenizer1, tokenizer2]
    else:
        tokenizer = train_util.load_tokenizer(args)
        tokenizers = [tokenizer]

    # Prepare the dataset.
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                print(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                print("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                print("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    # NOTE(review): this is keyed on the *requested* worker count, while n_workers below
    # may additionally be clamped to 0 on a single-CPU host — confirm the collater copes.
    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

    # Prepare the accelerator.
    print("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # Prepare the dtype matching the mixed precision setting; tensors are cast as needed.
    weight_dtype, _ = train_util.prepare_dtype(args)

    # Load the model; only the text encoders are kept.
    print("load model")
    if args.sdxl:
        (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
        text_encoders = [text_encoder1, text_encoder2]
    else:
        text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
        text_encoders = [text_encoder1]

    for text_encoder in text_encoders:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    # Prepare the dataloader.
    train_dataset_group.set_caching_mode("text")

    # Number of DataLoader processes (0 means the main process): cpu_count - 1, capped at the requested value.
    # os.cpu_count() may return None when the count cannot be determined; treat that as a single CPU.
    n_workers = min(args.max_data_loader_n_workers, (os.cpu_count() or 1) - 1)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collater,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # Let the accelerator wrap the dataloader; this should make multi-GPU runs work.
    train_dataloader = accelerator.prepare(train_dataloader)

    # Loop that pulls the data and writes the caches.
    for batch in tqdm(train_dataloader):
        absolute_paths = batch["absolute_paths"]
        input_ids1_list = batch["input_ids1_list"]
        input_ids2_list = batch["input_ids2_list"]

        image_infos = []
        for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
            # Placeholder ImageInfo that only carries what the caching helper is given here.
            image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
            image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX

            if args.skip_existing:
                if os.path.exists(image_info.text_encoder_outputs_npz):
                    print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
                    continue

            image_info.input_ids1 = input_ids1
            image_info.input_ids2 = input_ids2
            image_infos.append(image_info)

        # Everything in this batch may have been skipped; only encode when something is left.
        if len(image_infos) > 0:
            b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
            b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
            train_util.cache_batch_text_encoder_outputs(
                image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
            )

    accelerator.wait_for_everyone()
    accelerator.print(f"Finished caching text encoder outputs for {len(train_dataset_group)} batches.")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_parser() -> argparse.ArgumentParser:
    """Build the command line parser for the text encoder output caching tool.

    Registers the argument groups provided by ``train_util``, ``config_util``
    and ``sdxl_train_util`` and adds the tool-specific ``--sdxl`` and
    ``--skip_existing`` flags.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()

    # Shared argument groups; the meaning of the positional booleans is defined by the helpers themselves.
    train_util.add_sd_models_arguments(parser)
    train_util.add_training_arguments(parser, True)
    train_util.add_dataset_arguments(parser, True, True, True)
    config_util.add_config_arguments(parser)
    sdxl_train_util.add_sdxl_training_arguments(parser)

    # Tool-specific flags.
    parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        # Only the single text encoder output npz is checked; there is no flipped variant for this cache.
        help="skip images if npz already exists / npzが既に存在する画像をスキップする",
    )
    return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: build the CLI, parse argv, let train_util.read_config_from_file
    # post-process the parsed namespace, then run the caching job.
    parser = setup_parser()
    args = train_util.read_config_from_file(parser.parse_args(), parser)

    cache_to_disk(args)
|
||||||
Reference in New Issue
Block a user