diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index a93ef502..48785ca6 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -140,62 +140,6 @@ def load_tokenizers(args: argparse.Namespace):
     return tokeniers
 
 
-def get_hidden_states(
-    args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
-):
-    # input_ids: b,n,77 -> b*n, 77
-    b_size = input_ids1.size()[0]
-    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
-    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
-
-    # text_encoder1
-    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
-    hidden_states1 = enc_out["hidden_states"][11]
-
-    # text_encoder2
-    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
-    hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
-    pool2 = enc_out["text_embeds"]
-
-    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
-    n_size = 1 if args.max_token_length is None else args.max_token_length // 75
-    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
-    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
-
-    if args.max_token_length is not None:
-        # bs*3, 77, 768 or 1280
-        # encoder1: restore the triple <BOS>...<EOS> <BOS>...<EOS> <BOS>...<EOS> into a single <BOS>...<EOS>
-        states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
-        for i in range(1, args.max_token_length, tokenizer1.model_max_length):
-            states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
-        states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
-        hidden_states1 = torch.cat(states_list, dim=1)
-
-        # v2: restore the triple <BOS>...<EOS> <BOS>...<EOS> ... into <BOS>...<EOS> ...; honestly not sure this implementation is right
-        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
-        for i in range(1, args.max_token_length, tokenizer2.model_max_length):
-            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
-            # this causes an error:
-            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
-            # if i > 1:
-            #     for j in range(len(chunk)):  # batch_size
-            #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
-            #             chunk[j, 0] = chunk[j, 1]  # copy the value of the next <PAD>
-            states_list.append(chunk)  # from after <BOS> to before <EOS>
-        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
-        hidden_states2 = torch.cat(states_list, dim=1)
-
-        # for the pooled output, take the first of every n entries
-        pool2 = pool2[::n_size]
-
-    if weight_dtype is not None:
-        # this is required for additional network training
-        hidden_states1 = hidden_states1.to(weight_dtype)
-        hidden_states2 = hidden_states2.to(weight_dtype)
-
-    return hidden_states1, hidden_states2, pool2
-
-
 def timestep_embedding(timesteps, dim, max_period=10000):
     """
     Create sinusoidal timestep embeddings.
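This removal is a move, not a deletion: the helper reappears in library/train_util.py as get_hidden_states_sdxl later in this diff, taking max_token_length directly instead of the whole argparse namespace. A minimal before/after sketch of a call site, using only names that appear in this diff:

    # before: SDXL helper in sdxl_train_util, driven by the args namespace
    hs1, hs2, pool2 = sdxl_train_util.get_hidden_states(
        args, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype
    )

    # after: shared helper in train_util; only max_token_length is taken from args
    hs1, hs2, pool2 = train_util.get_hidden_states_sdxl(
        args.max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype
    )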
@@ -391,6 +335,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
     )
+    parser.add_argument(
+        "--cache_text_encoder_outputs_to_disk",
+        action="store_true",
+        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
+    )
 
 
 def verify_sdxl_training_args(args: argparse.Namespace):
@@ -417,6 +366,13 @@ def verify_sdxl_training_args(args: argparse.Namespace):
         not hasattr(args, "weighted_captions") or not args.weighted_captions
     ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
 
+    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
+        args.cache_text_encoder_outputs = True
+        print(
+            "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+            + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
+        )
+
 
 def sample_images(*args, **kwargs):
     return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
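Because verify_sdxl_training_args upgrades the disk flag into the in-memory flag, passing only --cache_text_encoder_outputs_to_disk is enough to enable both. A hedged invocation sketch (paths are hypothetical and the many unrelated required options are omitted):

    accelerate launch sdxl_train.py \
        --pretrained_model_name_or_path=path/to/sd_xl_base.safetensors \
        --train_data_dir=path/to/train/images \
        --output_dir=path/to/output \
        --cache_text_encoder_outputs_to_disk

Cached outputs are fixed once written, so this is presumably only useful while the text encoders stay frozen; the trainers below read the cache instead of running the encoders.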
diff --git a/library/train_util.py b/library/train_util.py
index bf333299..c36651e2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -104,6 +104,8 @@ IMAGE_TRANSFORMS = transforms.Compose(
     ]
 )
 
+TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
+
 
 class ImageInfo:
     def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -122,6 +124,11 @@ class ImageInfo:
         self.latents_crop_left_top: Tuple[int, int] = None  # original image crop left top, not latents crop left top
         self.cond_img_path: str = None
         self.image: Optional[Image.Image] = None  # optional, original PIL Image
+        # SDXL, optional
+        self.text_encoder_outputs_npz: Optional[str] = None
+        self.text_encoder_outputs1: Optional[torch.Tensor] = None
+        self.text_encoder_outputs2: Optional[torch.Tensor] = None
+        self.text_encoder_pool2: Optional[torch.Tensor] = None
 
 
 class BucketManager:
@@ -793,7 +800,7 @@ class BaseDataset(torch.utils.data.Dataset):
         )
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
-        # made this a bit faster
+        # multi-GPU is not supported; use tools/cache_latents.py for that
         print("caching latents.")
 
         image_infos = list(self.image_data.values())
@@ -841,9 +848,73 @@ class BaseDataset(torch.utils.data.Dataset):
             return
 
         # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
+        print("caching latents...")
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
             cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
 
+    # if weight_dtype is given, the Text Encoders themselves and their outputs are cast to weight_dtype
+    # SDXL only for now, but this must be a dataset method, so it lives here rather than in sdxl_train_util.py
+    # supporting SD1/2 would require carrying a v2 flag, so that is deferred
+    def cache_text_encoder_outputs(
+        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
+    ):
+        assert len(tokenizers) == 2, "only support SDXL"
+
+        # as with the latents cache, caching to disk is supported
+        # multi-GPU is not supported; use tools/cache_text_encoder_outputs.py for that
+        print("caching text encoder outputs.")
+        image_infos = list(self.image_data.values())
+
+        print("checking cache existence...")
+        image_infos_to_cache = []
+        for info in tqdm(image_infos):
+            # subset = self.image_to_subset[info.image_key]
+            if cache_to_disk:
+                te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
+                info.text_encoder_outputs_npz = te_out_npz
+
+                if not is_main_process:  # store to info only
+                    continue
+
+                if os.path.exists(te_out_npz):
+                    continue
+
+            image_infos_to_cache.append(info)
+
+        if cache_to_disk and not is_main_process:  # when caching to disk, non-main processes only record the paths to info
+            return
+
+        # prepare tokenizers and text encoders
+        for text_encoder in text_encoders:
+            text_encoder.to(device)
+            if weight_dtype is not None:
+                text_encoder.to(dtype=weight_dtype)
+
+        # create batch
+        batch = []
+        batches = []
+        for info in image_infos_to_cache:
+            input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
+            input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
+            batch.append((info, input_ids1, input_ids2))
+
+            if len(batch) >= self.batch_size:
+                batches.append(batch)
+                batch = []
+
+        if len(batch) > 0:
+            batches.append(batch)
+
+        # iterate batches: call text encoder and cache outputs for memory or disk
+        print("caching text encoder outputs...")
+        for batch in tqdm(batches):
+            infos, input_ids1, input_ids2 = zip(*batch)
+            input_ids1 = torch.stack(input_ids1, dim=0)
+            input_ids2 = torch.stack(input_ids2, dim=0)
+            cache_batch_text_encoder_outputs(
+                infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
+            )
+
     def get_image_size(self, image_path):
         image = Image.open(image_path)
         return image.size
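The disk cache lives next to each image: the image extension is swapped for the _te_outputs.npz suffix defined above. A quick sketch with a hypothetical path:

    import os
    from library import train_util  # TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"

    image_path = "/data/dataset/0001.png"  # hypothetical image
    te_out_npz = os.path.splitext(image_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
    print(te_out_npz)  # /data/dataset/0001_te_outputs.npz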
@@ -931,6 +1002,9 @@ class BaseDataset(torch.utils.data.Dataset):
         crop_top_lefts = []
         target_sizes_hw = []
         flippeds = []  # this variable name is a bit awkward
+        text_encoder_outputs1_list = []
+        text_encoder_outputs2_list = []
+        text_encoder_pool2_list = []
 
         for image_key in bucket[image_index : image_index + bucket_batch_size]:
             image_info = self.image_data[image_key]
@@ -1012,44 +1086,76 @@ class BaseDataset(torch.utils.data.Dataset):
             target_sizes_hw.append((target_size[1], target_size[0]))
             flippeds.append(flipped)
 
-            caption = self.process_caption(subset, image_info.caption)
-            if self.XTI_layers:
-                caption_layer = []
-                for layer in self.XTI_layers:
-                    token_strings_from = " ".join(self.token_strings)
-                    token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
-                    caption_ = caption.replace(token_strings_from, token_strings_to)
-                    caption_layer.append(caption_)
-                captions.append(caption_layer)
-            else:
+            # process the caption and the text encoder outputs
+            caption = image_info.caption  # default
+            if image_info.text_encoder_outputs1 is not None:
+                text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
+                text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
+                text_encoder_pool2_list.append(image_info.text_encoder_pool2)
                 captions.append(caption)
-
-            if not self.token_padding_disabled:  # this option might be omitted in future
+            elif image_info.text_encoder_outputs_npz is not None:
+                text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
+                    image_info.text_encoder_outputs_npz
+                )
+                text_encoder_outputs1_list.append(text_encoder_outputs1)
+                text_encoder_outputs2_list.append(text_encoder_outputs2)
+                text_encoder_pool2_list.append(text_encoder_pool2)
+                captions.append(caption)
+            else:
+                caption = self.process_caption(subset, image_info.caption)
                 if self.XTI_layers:
-                    token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
+                    caption_layer = []
+                    for layer in self.XTI_layers:
+                        token_strings_from = " ".join(self.token_strings)
+                        token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
+                        caption_ = caption.replace(token_strings_from, token_strings_to)
+                        caption_layer.append(caption_)
+                    captions.append(caption_layer)
                 else:
-                    token_caption = self.get_input_ids(caption, self.tokenizers[0])
-                input_ids_list.append(token_caption)
+                    captions.append(caption)
 
-                if len(self.tokenizers) > 1:
+                if not self.token_padding_disabled:  # this option might be omitted in future
                     if self.XTI_layers:
-                        token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
+                        token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
                     else:
-                        token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
-                    input_ids2_list.append(token_caption2)
+                        token_caption = self.get_input_ids(caption, self.tokenizers[0])
+                    input_ids_list.append(token_caption)
+
+                    if len(self.tokenizers) > 1:
+                        if self.XTI_layers:
+                            token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
+                        else:
+                            token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
+                        input_ids2_list.append(token_caption2)
 
         example = {}
         example["loss_weights"] = torch.FloatTensor(loss_weights)
 
-        if self.token_padding_disabled:
-            # padding=True means pad in the batch
-            example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
-            if len(self.tokenizers) > 1:
-                # following may not work in SDXL, keep the line for future update
-                example["input_ids2"] = self.tokenizer[1](captions, padding=True, truncation=True, return_tensors="pt").input_ids
+        if len(text_encoder_outputs1_list) == 0:
+            if self.token_padding_disabled:
+                # padding=True means pad in the batch
+                example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
+                if len(self.tokenizers) > 1:
+                    example["input_ids2"] = self.tokenizer[1](
+                        captions, padding=True, truncation=True, return_tensors="pt"
+                    ).input_ids
+                else:
+                    example["input_ids2"] = None
+            else:
+                example["input_ids"] = torch.stack(input_ids_list)
+                example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
+            example["text_encoder_outputs1_list"] = None
+            example["text_encoder_outputs2_list"] = None
+            example["text_encoder_pool2_list"] = None
         else:
-            example["input_ids"] = torch.stack(input_ids_list)
-            example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
+            example["input_ids"] = None
+            example["input_ids2"] = None
+            # # for assertion
+            # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
+            # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
+            example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
+            example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
+            example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
 
         if images[0] is not None:
             images = torch.stack(images)
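After this hunk a batch carries exactly one of two representations: token ids to be encoded during the step, or pre-computed encoder outputs. A sketch of the resulting invariant (b = batch size, n = max_token_length // 75 when long prompts are enabled):

    if example["text_encoder_outputs1_list"] is None:
        assert example["input_ids"] is not None   # token path: encode during the training step
    else:
        assert example["input_ids"] is None       # cached path
        # stacked shapes: (b, n*75+2, 768), (b, n*75+2, 1280), and (b, 1280) for the pooled output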
@@ -1073,6 +1179,8 @@ class BaseDataset(torch.utils.data.Dataset):
     def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
         captions = []
         images = []
+        input_ids1_list = []
+        input_ids2_list = []
         absolute_paths = []
         resized_sizes = []
         bucket_reso = None
@@ -1092,14 +1200,24 @@ class BaseDataset(torch.utils.data.Dataset):
             assert random_crop == subset.random_crop, "random_crop must be same in a batch"
             assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
 
-            caption = image_info.caption  # TODO cache some patterns of droping, shuffling, etc.
+            caption = image_info.caption  # TODO cache some patterns of dropping, shuffling, etc.
+
             if self.caching_mode == "latents":
                 image = load_image(image_info.absolute_path)
             else:
                 image = None
 
+            if self.caching_mode == "text":
+                input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
+                input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
+            else:
+                input_ids1 = None
+                input_ids2 = None
+
             captions.append(caption)
             images.append(image)
+            input_ids1_list.append(input_ids1)
+            input_ids2_list.append(input_ids2)
             absolute_paths.append(image_info.absolute_path)
             resized_sizes.append(image_info.resized_size)
@@ -1110,6 +1228,8 @@ class BaseDataset(torch.utils.data.Dataset):
         example["images"] = images
 
         example["captions"] = captions
+        example["input_ids1_list"] = input_ids1_list
+        example["input_ids2_list"] = input_ids2_list
         example["absolute_paths"] = absolute_paths
         example["resized_sizes"] = resized_sizes
         example["flip_aug"] = flip_aug
@@ -1680,6 +1800,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             print(f"[Dataset {i}]")
             dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
 
+    def cache_text_encoder_outputs(
+        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
+    ):
+        for i, dataset in enumerate(self.datasets):
+            print(f"[Dataset {i}]")
+            dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
+
     def set_caching_mode(self, caching_mode):
         for dataset in self.datasets:
             dataset.set_caching_mode(caching_mode)
@@ -1982,6 +2109,7 @@ def cache_batch_latents(
     images = []
     for info in image_infos:
         image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
+        # TODO: image metadata can be broken, so the bucket assigned from metadata may not match the actual image size; a check should be added
        image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
         image = IMAGE_TRANSFORMS(image)
         images.append(image)
@@ -2015,6 +2143,55 @@ def cache_batch_latents(
             info.latents_flipped = flipped_latent
 
 
+def cache_batch_text_encoder_outputs(
+    image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
+):
+    input_ids1 = input_ids1.to(text_encoders[0].device)
+    input_ids2 = input_ids2.to(text_encoders[1].device)
+
+    with torch.no_grad():
+        b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
+            max_token_length,
+            input_ids1,
+            input_ids2,
+            tokenizers[0],
+            tokenizers[1],
+            text_encoders[0],
+            text_encoders[1],
+            dtype,
+        )
+
+        # move to CPU here, otherwise the tensors would be overwritten by the next batch
+        b_hidden_state1 = b_hidden_state1.detach().to("cpu")  # b,n*75+2,768
+        b_hidden_state2 = b_hidden_state2.detach().to("cpu")  # b,n*75+2,1280
+        b_pool2 = b_pool2.detach().to("cpu")  # b,1280
+
+    for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
+        if cache_to_disk:
+            save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden_state1, hidden_state2, pool2)
+        else:
+            info.text_encoder_outputs1 = hidden_state1
+            info.text_encoder_outputs2 = hidden_state2
+            info.text_encoder_pool2 = pool2
+
+
+def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
+    np.savez(
+        npz_path,
+        hidden_state1=hidden_state1.cpu().float().numpy(),
+        hidden_state2=hidden_state2.cpu().float().numpy(),
+        pool2=pool2.cpu().float().numpy(),
+    )
+
+
+def load_text_encoder_outputs_from_disk(npz_path):
+    with np.load(npz_path) as f:
+        hidden_state1 = torch.from_numpy(f["hidden_state1"])
+        hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
+        pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
+    return hidden_state1, hidden_state2, pool2
+
+
 # endregion
 
 # region module replacement
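Each cache file therefore holds three float32 arrays per image. A minimal round-trip sketch (hypothetical file name; shapes assume max_token_length=225, i.e. n=3):

    import numpy as np
    import torch

    with np.load("0001_te_outputs.npz") as f:                 # hypothetical cache file
        hidden_state1 = torch.from_numpy(f["hidden_state1"])  # (227, 768), CLIP ViT-L hidden states
        hidden_state2 = torch.from_numpy(f["hidden_state2"])  # (227, 1280), OpenCLIP ViT-bigG hidden states
        pool2 = torch.from_numpy(f["pool2"])                  # (1280,), pooled text embedding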
@@ -3501,6 +3678,62 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
     return encoder_hidden_states
 
 
+def get_hidden_states_sdxl(
+    max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
+):
+    # input_ids: b,n,77 -> b*n, 77
+    b_size = input_ids1.size()[0]
+    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
+    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
+
+    # text_encoder1
+    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
+    hidden_states1 = enc_out["hidden_states"][11]
+
+    # text_encoder2
+    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
+    hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
+    pool2 = enc_out["text_embeds"]
+
+    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
+    n_size = 1 if max_token_length is None else max_token_length // 75
+    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
+    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
+
+    if max_token_length is not None:
+        # bs*3, 77, 768 or 1280
+        # encoder1: restore the triple <BOS>...<EOS> <BOS>...<EOS> <BOS>...<EOS> into a single <BOS>...<EOS>
+        states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, max_token_length, tokenizer1.model_max_length):
+            states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
+        states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
+        hidden_states1 = torch.cat(states_list, dim=1)
+
+        # v2: restore the triple <BOS>...<EOS> <BOS>...<EOS> ... into <BOS>...<EOS> ...; honestly not sure this implementation is right
+        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, max_token_length, tokenizer2.model_max_length):
+            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
+            # this causes an error:
+            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+            # if i > 1:
+            #     for j in range(len(chunk)):  # batch_size
+            #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
+            #             chunk[j, 0] = chunk[j, 1]  # copy the value of the next <PAD>
+            states_list.append(chunk)  # from after <BOS> to before <EOS>
+        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
+        hidden_states2 = torch.cat(states_list, dim=1)
+
+        # for the pooled output, take the first of every n entries
+        pool2 = pool2[::n_size]
+
+    if weight_dtype is not None:
+        # this is required for additional network training
+        hidden_states1 = hidden_states1.to(weight_dtype)
+        hidden_states2 = hidden_states2.to(weight_dtype)
+
+    return hidden_states1, hidden_states2, pool2
+
+
 def default_if_none(value, default):
     return default if value is None else value
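To make the reshaping concrete, here is the worked arithmetic for the long-prompt case (a sketch; 77 is tokenizer.model_max_length and 75 the payload tokens per window):

    max_token_length = 225              # three 75-token windows
    n = max_token_length // 75          # n_size = 3 chunks per prompt
    # input_ids arrive as (b, n, 77) and are encoded as (b*n, 77, d)
    out_len = n * (77 - 2) + 2          # drop the inner <BOS>/<EOS> pairs, keep one pair: 227
    # hidden_states1 becomes (b, 227, 768); hidden_states2 becomes (b, 227, 1280)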
diff --git a/sdxl_train.py b/sdxl_train.py
index 935992bf..8459671c 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -204,10 +204,6 @@ def train(args):
             text_encoder2.gradient_checkpointing_enable()
         training_models.append(text_encoder1)
         training_models.append(text_encoder2)
-
-        text_encoder1_cache = None
-        text_encoder2_cache = None
-
         # set require_grad=True later
     else:
         text_encoder1.requires_grad_(False)
@@ -218,9 +214,15 @@ def train(args):
     # cache the Text Encoder outputs
     if args.cache_text_encoder_outputs:
         # Text Encodes are eval and no grad
-        text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
-            args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
-        )
+        with torch.no_grad():
+            train_dataset_group.cache_text_encoder_outputs(
+                (tokenizer1, tokenizer2),
+                (text_encoder1, text_encoder2),
+                accelerator.device,
+                None,
+                args.cache_text_encoder_outputs_to_disk,
+                accelerator.is_main_process,
+            )
         accelerator.wait_for_everyone()
 
     if not cache_latents:
@@ -375,11 +377,10 @@ def train(args):
                     accelerator.print("NaN found in latents, replacing with zeros")
                     latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
                 latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
-                b_size = latents.shape[0]
 
-                input_ids1 = batch["input_ids"]
-                input_ids2 = batch["input_ids2"]
-                if not args.cache_text_encoder_outputs:
+                if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
+                    input_ids1 = batch["input_ids"]
+                    input_ids2 = batch["input_ids2"]
                     with torch.set_grad_enabled(args.train_text_encoder):
                         # Get the text embedding for conditioning
                         # TODO support weighted captions
@@ -395,8 +396,8 @@ def train(args):
                         # else:
                         input_ids1 = input_ids1.to(accelerator.device)
                         input_ids2 = input_ids2.to(accelerator.device)
-                        encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
-                            args,
+                        encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
+                            args.max_token_length,
                             input_ids1,
                             input_ids2,
                             tokenizer1,
@@ -406,19 +407,26 @@ def train(args):
                             None if not args.full_fp16 else weight_dtype,
                         )
                 else:
-                    encoder_hidden_states1 = []
-                    encoder_hidden_states2 = []
-                    pool2 = []
-                    for input_id1, input_id2 in zip(input_ids1, input_ids2):
-                        input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
-                        input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
-                        encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
-                        hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
-                        encoder_hidden_states2.append(hidden_states2)
-                        pool2.append(p2)
-                    encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
-                    encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
-                    pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
+                    encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
+                    encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
+                    pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
+
+                    # # verify that the text encoder outputs are correct
+                    # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
+                    #     args.max_token_length,
+                    #     batch["input_ids"].to(text_encoder1.device),
+                    #     batch["input_ids2"].to(text_encoder1.device),
+                    #     tokenizer1,
+                    #     tokenizer2,
+                    #     text_encoder1,
+                    #     text_encoder2,
+                    #     None if not args.full_fp16 else weight_dtype,
+                    # )
+                    # b_size = encoder_hidden_states1.shape[0]
+                    # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # print("text encoder outputs verified")
 
                 # get size embeddings
                 orig_size = batch["original_sizes_hw"]
diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index 0c3c0cc5..dc222534 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
         return args.cache_text_encoder_outputs
 
     def cache_text_encoder_outputs_if_needed(
-        self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype
+        self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
     ):
         if args.cache_text_encoder_outputs:
             if not args.lowram:
@@ -60,34 +60,33 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
 
-            text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
-                args, accelerator, tokenizers, text_encoders, dataset, weight_dtype
+            dataset.cache_text_encoder_outputs(
+                tokenizers,
+                text_encoders,
+                accelerator.device,
+                weight_dtype,
+                args.cache_text_encoder_outputs_to_disk,
+                accelerator.is_main_process,
             )
-            accelerator.wait_for_everyone()
-            text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
+
+            text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
             text_encoders[1].to("cpu", dtype=torch.float32)
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
-            self.text_encoder1_cache = text_encoder1_cache
-            self.text_encoder2_cache = text_encoder2_cache
-
             if not args.lowram:
                 print("move vae and unet back to original device")
                 vae.to(org_vae_device)
                 unet.to(org_unet_device)
         else:
-            self.text_encoder1_cache = None
-            self.text_encoder2_cache = None
-
             # outputs are fetched from the Text Encoders every step, so keep them on the GPU
             text_encoders[0].to(accelerator.device)
             text_encoders[1].to(accelerator.device)
 
     def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
-        input_ids1 = batch["input_ids"]
-        input_ids2 = batch["input_ids2"]
-        if not args.cache_text_encoder_outputs:
+        if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
+            input_ids1 = batch["input_ids"]
+            input_ids2 = batch["input_ids2"]
             with torch.enable_grad():
                 # Get the text embedding for conditioning
                 # TODO support weighted captions
@@ -103,8 +102,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                 # else:
                 input_ids1 = input_ids1.to(accelerator.device)
                 input_ids2 = input_ids2.to(accelerator.device)
-                encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
-                    args,
+                encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
+                    args.max_token_length,
                     input_ids1,
                     input_ids2,
                     tokenizers[0],
@@ -114,19 +113,27 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                     None if not args.full_fp16 else weight_dtype,
                 )
         else:
-            encoder_hidden_states1 = []
-            encoder_hidden_states2 = []
-            pool2 = []
-            for input_id1, input_id2 in zip(input_ids1, input_ids2):
-                input_id1_cache_key = tuple(input_id1.flatten().tolist())
-                input_id2_cache_key = tuple(input_id2.flatten().tolist())
-                encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
-                hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
-                encoder_hidden_states2.append(hidden_states2)
-                pool2.append(p2)
-            encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
-            encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
-            pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
+            encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
+            encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
+            pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
+
+            # # verify that the text encoder outputs are correct
+            # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
+            #     args.max_token_length,
+            #     batch["input_ids"].to(text_encoders[0].device),
+            #     batch["input_ids2"].to(text_encoders[0].device),
+            #     tokenizers[0],
+            #     tokenizers[1],
+            #     text_encoders[0],
+            #     text_encoders[1],
+            #     None if not args.full_fp16 else weight_dtype,
+            # )
+            # b_size = encoder_hidden_states1.shape[0]
+            # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+            # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+            # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+            # print("text encoder outputs verified")
+
         return encoder_hidden_states1, encoder_hidden_states2, pool2
batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # # verify that the text encoder outputs are correct + # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( + # args.max_token_length, + # batch["input_ids"].to(text_encoders[0].device), + # batch["input_ids2"].to(text_encoders[0].device), + # tokenizers[0], + # tokenizers[1], + # text_encoders[0], + # text_encoders[1], + # None if not args.full_fp16 else weight_dtype, + # ) + # b_size = encoder_hidden_states1.shape[0] + # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # print("text encoder outputs verified") + return encoder_hidden_states1, encoder_hidden_states2, pool2 diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 2616a22c..c649c9f4 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -45,8 +45,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine with torch.enable_grad(): input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states( - args, + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, input_ids1, input_ids2, tokenizers[0], diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py new file mode 100644 index 00000000..2110e726 --- /dev/null +++ b/tools/cache_text_encoder_outputs.py @@ -0,0 +1,191 @@ +# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance + +import argparse +import math +from multiprocessing import Value +import os + +from accelerate.utils import set_seed +import torch +from tqdm import tqdm + +from library import config_util +from library import train_util +from library import sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) + + +def cache_to_disk(args: argparse.Namespace) -> None: + train_util.prepare_dataset_args(args, True) + + # check cache arg + assert ( + args.cache_text_encoder_outputs_to_disk + ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" + + # できるだけ準備はしておくが今のところSDXLのみしか動かない + assert ( + args.sdxl + ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" + + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # tokenizerを準備する:datasetを動かすために必要 + if args.sdxl: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizers = [tokenizer1, tokenizer2] + else: + tokenizer = train_util.load_tokenizer(args) + tokenizers = [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + 
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, _ = train_util.prepare_dtype(args) + + # モデルを読み込む + print("load model") + if args.sdxl: + (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + text_encoders = [text_encoder1, text_encoder2] + else: + text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + text_encoders = [text_encoder1] + + for text_encoder in text_encoders: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + # dataloaderを準備する + train_dataset_group.set_caching_mode("text") + + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず + train_dataloader = accelerator.prepare(train_dataloader) + + # データ取得のためのループ + for batch in tqdm(train_dataloader): + absolute_paths = batch["absolute_paths"] + input_ids1_list = batch["input_ids1_list"] + input_ids2_list = batch["input_ids2_list"] + + image_infos = [] + for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): + image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) + image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + image_info + + if args.skip_existing: + if os.path.exists(image_info.text_encoder_outputs_npz): + print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + continue + + image_info.input_ids1 = input_ids1 + image_info.input_ids2 = input_ids2 + image_infos.append(image_info) + + if len(image_infos) > 0: + b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) + b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) + train_util.cache_batch_text_encoder_outputs( + image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype + ) + + accelerator.wait_for_everyone() + 
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_training_arguments(parser, True) + train_util.add_dataset_arguments(parser, True, True, True) + config_util.add_config_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + cache_to_disk(args)