From 9e9df2b5017c4dded9a7be1f46e916df248dd5c1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Jun 2023 17:56:02 +0900 Subject: [PATCH] update dataset to return size, refactor ctrlnet ds --- library/config_util.py | 2 +- library/sdxl_original_unet.py | 10 +- library/train_util.py | 625 ++++++++++++++++++---------------- 3 files changed, 333 insertions(+), 304 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 36c165a5..dd81ae66 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -79,7 +79,7 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: CLIPTokenizer = None + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None debug_dataset: bool = False diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 6556b12c..fd37432f 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -1116,13 +1116,15 @@ if __name__ == "__main__": # 使用メモリ量確認用の疑似学習ループ print("preparing optimizer") - import bitsandbytes - import transformers + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + # import bitsandbytes # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 - # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + import transformers + optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 scaler = torch.cuda.amp.GradScaler(enabled=True) @@ -1133,7 +1135,7 @@ if __name__ == "__main__": for step in range(steps): print(f"step {step}") - x = torch.randn(batch_size, 4, 128, 128).cuda() # 512x512 + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda") ctx = torch.randn(batch_size, 77, 2048).cuda() y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda() diff --git a/library/train_util.py b/library/train_util.py index 89ad683b..533bf0a9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -103,6 +103,9 @@ class ImageInfo: self.latents_flipped: torch.Tensor = None self.latents_npz: str = None self.latents_npz_flipped: str = None + self.latents_original_size: Tuple[int, int] = None # original image size, not latents size + self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top + self.cond_img_path: str = None class BucketManager: @@ -171,6 +174,7 @@ class BucketManager: def select_bucket(self, image_width, image_height): aspect_ratio = image_width / image_height if not self.no_upscale: + # 拡大および縮小を行う # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する reso = (image_width, image_height) if reso in self.predefined_resos_set: @@ -189,6 +193,7 @@ class BucketManager: resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) # print("use predef", image_width, image_height, reso, resized_size) else: + # 縮小のみを行う if image_width * image_height > self.max_area: # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める resized_width = math.sqrt(self.max_area * aspect_ratio) @@ -238,41 +243,40 @@ class BucketBatchIndex(NamedTuple): class AugHelper: def __init__(self): # prepare all possible augmentators - color_aug_method = albu.OneOf( + self.color_aug_method = albu.OneOf( [ albu.HueSaturationValue(8, 0, 0, p=0.5), albu.RandomGamma((95, 105), p=0.5), ], p=0.33, ) - flip_aug_method = albu.HorizontalFlip(p=0.5) # key: (use_color_aug, use_flip_aug) - self.augmentors = { - (True, True): albu.Compose( - [ - color_aug_method, - flip_aug_method, - ], - p=1.0, - ), - (True, False): albu.Compose( - [ - color_aug_method, - ], - p=1.0, - ), - (False, True): albu.Compose( - [ - flip_aug_method, - ], - p=1.0, - ), - (False, False): None, - } + # self.augmentors = { + # (True, True): albu.Compose( + # [ + # color_aug_method, + # flip_aug_method, + # ], + # p=1.0, + # ), + # (True, False): albu.Compose( + # [ + # color_aug_method, + # ], + # p=1.0, + # ), + # (False, True): albu.Compose( + # [ + # flip_aug_method, + # ], + # p=1.0, + # ), + # (False, False): None, + # } - def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: - return self.augmentors[(use_color_aug, use_flip_aug)] + def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]: + return self.color_aug_method if use_color_aug else None class BaseSubset: @@ -454,10 +458,16 @@ class ControlNetSubset(BaseSubset): class BaseDataset(torch.utils.data.Dataset): def __init__( - self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool + self, + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], + max_token_length: int, + resolution: Optional[Tuple[int, int]], + debug_dataset: bool, ) -> None: super().__init__() - self.tokenizer = tokenizer + + self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution @@ -478,7 +488,7 @@ class BaseDataset(torch.utils.data.Dataset): self.bucket_no_upscale = None self.bucket_info = None # for metadata - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ @@ -594,48 +604,49 @@ class BaseDataset(torch.utils.data.Dataset): return caption - def get_input_ids(self, caption): - input_ids = self.tokenizer( + def get_input_ids(self, caption, tokenizer=None): + if tokenizer is None: + tokenizer = self.tokenizers[0] + + input_ids = tokenizer( caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" ).input_ids - if self.tokenizer_max_length > self.tokenizer.model_max_length: + if self.tokenizer_max_length > tokenizer.model_max_length: input_ids = input_ids.squeeze(0) iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + if tokenizer.pad_token_id == tokenizer.eos_token_id: # v1 # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に for i in range( - 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + 1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2 ): # (1, 152, 75) ids_chunk = ( input_ids[0].unsqueeze(0), - input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[i : i + tokenizer.model_max_length - 2], input_ids[-1].unsqueeze(0), ) ids_chunk = torch.cat(ids_chunk) iids_list.append(ids_chunk) else: - # v2 + # v2 or SDXL # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range( - 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 - ): + for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): ids_chunk = ( input_ids[0].unsqueeze(0), # BOS - input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[i : i + tokenizer.model_max_length - 2], input_ids[-1].unsqueeze(0), ) # PAD or EOS ids_chunk = torch.cat(ids_chunk) # 末尾が または の場合は、何もしなくてよい # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id iids_list.append(ids_chunk) @@ -755,46 +766,58 @@ class BaseDataset(torch.utils.data.Dataset): img = np.array(image, np.uint8) return img - def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size, cond_img = None): + def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): image_height, image_width = image.shape[0:2] if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - if exists(cond_img): - cond_img = cv2.resize(cond_img, resized_size, interpolation=cv2.INTER_AREA) image_height, image_width = image.shape[0:2] + original_size = (image_width, image_height) + + crop_left_top = (0, 0) if image_width > reso[0]: trim_size = image_width - reso[0] p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) # print("w", trim_size, p) image = image[:, p : p + reso[0]] - if exists(cond_img): - cond_img = cond_img[:, p : p + reso[0]] + crop_left_top = (p, 0) if image_height > reso[1]: trim_size = image_height - reso[1] p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) # print("h", trim_size, p) image = image[p : p + reso[1]] - if exists(cond_img): - cond_img = cond_img[p : p + reso[1]] + crop_left_top = (crop_left_top[0], p) assert ( image.shape[0] == reso[1] and image.shape[1] == reso[0] ), f"internal error, illegal trimmed size: {image.shape}, {reso}" - - if exists(cond_img): - assert ( - cond_img.shape[0] == reso[1] and cond_img.shape[1] == reso[0] - ), f"internal error, illegal trimmed size: {cond_img.shape}, {reso}" - return image, cond_img - - return image + return image, original_size, crop_left_top def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path): + expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 + + for npath in [npz_path, flipped_npz_path]: + if npath is None: + continue + if not os.path.exists(npath): + return False + + npz = np.load(npath) + if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver? + return False + + cached_latents = npz["latents"] + + if cached_latents.shape[1:3] != expected_latents_size: + return False + + return True + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): # ちょっと速くした print("caching latents.") @@ -811,38 +834,26 @@ class BaseDataset(torch.utils.data.Dataset): subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: - info.latents = self.load_latents_from_npz(info, False) + info.latents, info.latents_original_size, info.latents_crop_left_top = self.load_latents_from_npz(info, False) info.latents = torch.FloatTensor(info.latents) - # might be None, but that's ok because check is done in dataset - info.latents_flipped = self.load_latents_from_npz(info, True) + info.latents_flipped, _, _ = self.load_latents_from_npz(info, True) # might be None if info.latents_flipped is not None: info.latents_flipped = torch.FloatTensor(info.latents_flipped) continue # check disk cache exists and size of latents if cache_to_disk: - # TODO: refactor to unify with FineTuningDataset info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz" if not is_main_process: continue - cache_available = False - expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意 - if os.path.exists(info.latents_npz): - cached_latents = np.load(info.latents_npz)["arr_0"] - if cached_latents.shape[1:3] == expected_latents_size: - cache_available = True + cache_available = self.is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None + ) - if subset.flip_aug: - cache_available = False - if os.path.exists(info.latents_npz_flipped): - cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"] - if cached_latents_flipped.shape[1:3] == expected_latents_size: - cache_available = True - - if cache_available: + if cache_available: # do not add to batch continue # if last member of batch has different resolution, flush the batch @@ -868,10 +879,15 @@ class BaseDataset(torch.utils.data.Dataset): images = [] for info in batch: image = self.load_image(info.absolute_path) - image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) + image, original_size, crop_left_top = self.trim_and_resize_if_required( + subset, image, info.bucket_reso, info.resized_size + ) image = self.image_transforms(image) images.append(image) + info.latents_original_size = original_size + info.latents_crop_left_top = crop_left_top + img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) @@ -879,7 +895,12 @@ class BaseDataset(torch.utils.data.Dataset): for info, latent in zip(batch, latents): if cache_to_disk: - np.savez(info.latents_npz, latent.float().numpy()) + np.savez( + info.latents_npz, + latents=latent.float().numpy(), + original_size=np.array(info.latents_original_size), + crop_left_top=np.array(info.latents_crop_left_top), + ) else: info.latents = latent @@ -888,7 +909,12 @@ class BaseDataset(torch.utils.data.Dataset): latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") for info, latent in zip(batch, latents): if cache_to_disk: - np.savez(info.latents_npz_flipped, latent.float().numpy()) + np.savez( + info.latents_npz_flipped, + latents=latent.float().numpy(), + original_size=np.array(info.latents_original_size), + crop_left_top=np.array(info.latents_crop_left_top), # reverse horizontally when use flipped latents + ) else: info.latents_flipped = latent @@ -961,8 +987,13 @@ class BaseDataset(torch.utils.data.Dataset): def load_latents_from_npz(self, image_info: ImageInfo, flipped): npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz if npz_file is None: - return None - return np.load(npz_file)["arr_0"] + return None, None, None + + npz = np.load(npz_file) + latents = npz["latents"] + original_size = npz["original_size"].tolist() + crop_left_top = npz["crop_left_top"].tolist() + return latents, original_size, crop_left_top def __len__(self): return self._length @@ -975,21 +1006,35 @@ class BaseDataset(torch.utils.data.Dataset): loss_weights = [] captions = [] input_ids_list = [] + input_ids2_list = [] latents_list = [] images = [] + original_sizes_hw = [] + crop_top_lefts = [] + target_sizes_hw = [] + flippeds = [] # 変数名が微妙 for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance + # image/latentsを処理する if image_info.latents is not None: # cache_latents=Trueの場合 - latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped + original_size = image_info.latents_original_size + crop_left_top = image_info.latents_crop_left_top # calc values later if flipped + if not flipped: + latents = image_info.latents + else: + latents = image_info.latents_flipped + image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) + latents, original_size, crop_left_top = self.load_latents_from_npz(image_info, flipped) latents = torch.FloatTensor(latents) + image = None else: # 画像を読み込み、必要ならcropする @@ -997,7 +1042,9 @@ class BaseDataset(torch.utils.data.Dataset): im_h, im_w = img.shape[0:2] if self.enable_bucket: - img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + img, original_size, crop_left_top = self.trim_and_resize_if_required( + subset, img, image_info.bucket_reso, image_info.resized_size + ) else: if face_cx > 0: # 顔位置情報あり img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) @@ -1017,17 +1064,33 @@ class BaseDataset(torch.utils.data.Dataset): im_h == self.height and im_w == self.width ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + original_size = [im_w, im_h] + crop_left_top = [0, 0] + # augmentation - aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + aug = self.aug_helper.get_augmentor(subset.color_aug) if aug is not None: img = aug(image=img)["image"] + if flipped: + img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる images.append(image) latents_list.append(latents) + target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + + if flipped: + crop_left_top = (original_size[0] - crop_left_top[0] - target_size[0], crop_left_top[1]) + + original_sizes_hw.append((original_size[1], original_size[0])) + crop_top_lefts.append((crop_left_top[1], crop_left_top[0])) + target_sizes_hw.append((target_size[1], target_size[0])) + flippeds.append(flipped) + caption = self.process_caption(subset, image_info.caption) if self.XTI_layers: caption_layer = [] @@ -1039,22 +1102,33 @@ class BaseDataset(torch.utils.data.Dataset): captions.append(caption_layer) else: captions.append(caption) + if not self.token_padding_disabled: # this option might be omitted in future if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer) + token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) else: - token_caption = self.get_input_ids(caption) + token_caption = self.get_input_ids(caption, self.tokenizers[0]) input_ids_list.append(token_caption) + if len(self.tokenizers) > 1: + if self.XTI_layers: + token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + else: + token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + input_ids2_list.append(token_caption2) + example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) if self.token_padding_disabled: # padding=True means pad in the batch - example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids + if len(self.tokenizers) > 1: + # following may not work in SDXL, keep the line for future update + example["input_ids2"] = self.tokenizer[1](captions, padding=True, truncation=True, return_tensors="pt").input_ids else: - # batch processing seems to be good example["input_ids"] = torch.stack(input_ids_list) + example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None if images[0] is not None: images = torch.stack(images) @@ -1066,6 +1140,11 @@ class BaseDataset(torch.utils.data.Dataset): example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None example["captions"] = captions + example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw]) + example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts]) + example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) + example["flippeds"] = flippeds + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1462,151 +1541,86 @@ class ControlNetDataset(BaseDataset): max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset) -> None: + debug_dataset, + ) -> None: super().__init__(tokenizer, max_token_length, resolution, debug_dataset) - self.conditioning_image_data: Dict[str, ImageInfo] = {} - assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + db_subsets = [] + for subset in subsets: + db_subset = DreamBoothSubset( + subset.image_dir, + False, + None, + subset.caption_extension, + subset.num_repeats, + subset.shuffle_caption, + subset.keep_tokens, + subset.color_aug, + subset.flip_aug, + subset.face_crop_aug_range, + subset.random_crop, + subset.caption_dropout_rate, + subset.caption_dropout_every_n_epochs, + subset.caption_tag_dropout_rate, + subset.token_warmup_min, + subset.token_warmup_step, + ) + db_subsets.append(db_subset) + self.dreambooth_dataset_delegate = DreamBoothDataset( + db_subsets, + batch_size, + tokenizer, + max_token_length, + resolution, + enable_bucket, + min_bucket_reso, + max_bucket_reso, + bucket_reso_steps, + bucket_no_upscale, + 1.0, + debug_dataset, + ) + + # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) + self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size - self.size = min(self.width, self.height) # 短いほう - self.latents_cache = None + self.num_train_images = self.dreambooth_dataset_delegate.num_train_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images - self.num_reg_images = 0 - - self.enable_bucket = enable_bucket - if self.enable_bucket: - assert ( - min(resolution) >= min_bucket_reso - ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert ( - max(resolution) <= max_bucket_reso - ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None # この情報は使われない - self.bucket_no_upscale = False - - def read_caption(img_path, caption_extension): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding="utf-8") as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") - raise e - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() + # assert all conditioning data exists + missing_imgs = [] + cond_imgs_with_img = set() + for image_key, info in self.dreambooth_dataset_delegate.image_data.items(): + db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key] + subset = None + for s in subsets: + if s.image_dir == db_subset.image_dir: + subset = s break - return caption + assert subset is not None, "internal error: subset not found" - def load_controlnet_dir(subset: ControlNetSubset): - if not os.path.isdir(subset.image_dir): - print(f"not directory: {subset.image_dir}") - return [], [] if not os.path.isdir(subset.conditioning_data_dir): print(f"not directory: {subset.conditioning_data_dir}") - return [], [] + continue - img_paths = glob_images(subset.image_dir, "*") - conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") - img_paths = sorted(img_paths) - conditioning_img_paths = sorted(conditioning_img_paths) - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files") + img_basename = os.path.basename(info.absolute_path) + ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename) + if not os.path.exists(ctrl_img_path): + missing_imgs.append(img_basename) - img_basenames = [os.path.basename(img) for img in img_paths] - conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths] - missing_imgs = [] - extra_imgs = [] + info.cond_img_path = ctrl_img_path + cond_imgs_with_img.add(ctrl_img_path) - for img in img_basenames: - if img not in conditioning_img_basenames: - missing_imgs.append(img) - for img in conditioning_img_basenames: - if img not in img_basenames: - extra_imgs.append(img) - - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" - - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - missing_captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None: - print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}") - captions.append("") - missing_captions.append(img_path) - else: - captions.append(cap_for_img) - - self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 - - if missing_captions: - number_of_missing_captions = len(missing_captions) - number_of_missing_captions_to_show = 5 - remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show - - print( - f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" - ) - for i, missing_caption in enumerate(missing_captions): - if i >= number_of_missing_captions_to_show: - print(missing_caption + f"... and {remaining_missing_captions} more") - break - print(missing_caption) - return img_paths, conditioning_img_paths, captions - - print("prepare images.") - num_train_images = 0 + extra_imgs = [] for subset in subsets: - if subset.num_repeats < 1: - print( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" - ) - continue + conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") + extra_imgs.extend( + [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] + ) - if subset in self.subsets: - print( - f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" - ) - continue - - img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset) - if len(img_paths) < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") - continue - - num_train_images += subset.num_repeats * len(img_paths) - - for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions): - info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path) - setattr(info, "cond_img_path", cond_img_path) - self.register_image(info, subset) - - subset.img_count = len(img_paths) - self.subsets.append(subset) - - print(f"{num_train_images} train images with repeating.") - self.num_train_images = num_train_images + assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" + assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" self.conditioning_image_transforms = transforms.Compose( [ @@ -1614,88 +1628,58 @@ class ControlNetDataset(BaseDataset): ] ) - def __getitem__(self, index): - bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] - bucket_batch_size = self.buckets_indices[index].bucket_batch_size - image_index = self.buckets_indices[index].batch_index * bucket_batch_size + def make_buckets(self): + self.dreambooth_dataset_delegate.make_buckets() + self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager + self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices + + def __len__(self): + return self.dreambooth_dataset_delegate.__len__() + + def __getitem__(self, index): + example = self.dreambooth_dataset_delegate[index] + + bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[ + self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index + ] + bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size + image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size - loss_weights = [] - captions = [] - input_ids_list = [] - latents_list = [] - images = [] conditioning_images = [] - for image_key in bucket[image_index : image_index + bucket_batch_size]: - image_info = self.image_data[image_key] - subset = self.image_to_subset[image_key] - loss_weights.append(1.0) + for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]): + image_info = self.dreambooth_dataset_delegate.image_data[image_key] - assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}" + target_size_hw = example["target_sizes_hw"][i] + original_size_hw = example["original_sizes_hw"][i] + crop_top_left = example["crop_top_lefts"][i] + flipped = example["flippeds"][i] + cond_img = self.load_image(image_info.cond_img_path) - # image/latentsを処理する - if image_info.latents is not None: # cache_latents=Trueの場合 - latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped - image = None - elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) - latents = torch.FloatTensor(latents) - image = None + if self.dreambooth_dataset_delegate.enable_bucket: + cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + assert ( + cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] + ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" + ct, cl = crop_top_left + h, w = target_size_hw + cond_img = cond_img[ct : ct + h, cl : cl + w] else: - # 画像を読み込み、必要ならcropする - img = self.load_image(image_info.absolute_path) - cond_img = self.load_image(image_info.cond_img_path) - im_h, im_w = img.shape[0:2] + assert ( + cond_img.shape[0] == self.height and cond_img.shape[1] == self.width + ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" - if self.enable_bucket: - img, cond_img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size, cond_img=cond_img) - else: - im_h, im_w = img.shape[0:2] - assert ( - im_h == self.height and im_w == self.width - ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" - - # augmentation - aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) - if aug is not None: - img = aug(image=img)["image"] - - latents = None - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - - images.append(image) - latents_list.append(latents) + if flipped: + cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride cond_img = self.conditioning_image_transforms(cond_img) conditioning_images.append(cond_img) - caption = self.process_caption(subset, image_info.caption) - captions.append(caption) - token_caption = self.get_input_ids(caption) - input_ids_list.append(token_caption) - - example = {} - example["loss_weights"] = torch.FloatTensor(loss_weights) - - example["input_ids"] = torch.stack(input_ids_list) - - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example["images"] = images - - example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None - example["captions"] = captions - - if self.debug_dataset: - example["image_keys"] = bucket[image_index : image_index + self.batch_size] - example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float() return example + # behave as Dataset mock class DatasetGroup(torch.utils.data.ConcatDataset): def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): @@ -1773,18 +1757,42 @@ def debug_dataset(train_dataset, show_input_ids=False): example = train_dataset[idx] if example["latents"] is not None: print(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid) in enumerate( - zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( + zip( + example["image_keys"], + example["captions"], + example["loss_weights"], + example["input_ids"], + example["original_sizes_hw"], + example["crop_top_lefts"], + example["target_sizes_hw"], + example["flippeds"], + ) ): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + print( + f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop left top: {crptl}, target size: {trgsz}, flipped: {flpdz}' + ) + if show_input_ids: print(f"input ids: {iid}") + if "input_ids2" in example: + print(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] print(f"image size: {im.size()}") im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + + if "conditioning_images" in example: + cond_img = example["conditioning_images"][j] + print(f"conditioning image size: {cond_img.size()}") + cond_img = (cond_img.numpy() * 255.0).astype(np.uint8) + cond_img = np.transpose(cond_img, (1, 2, 0)) + cond_img = cond_img[:, :, ::-1] + if os.name == "nt": + cv2.imshow("cond_img", cond_img) + if os.name == "nt": # only windows cv2.imshow("img", im) k = cv2.waitKey() @@ -2011,7 +2019,6 @@ def get_git_revision_hash() -> str: return "(unknown)" - # def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): # replace_attentions_for_hypernetwork() # # unet is not used currently, but it is here for future use @@ -2063,8 +2070,9 @@ def get_git_revision_hash() -> str: # out = self.to_out[1](out) # return out + # diffusers.models.attention.CrossAttention.forward = forward_xformers -def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: print("Enable memory efficient attention for U-Net") unet.set_use_memory_efficient_attention(False, True) @@ -2080,6 +2088,7 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa print("Enable SDPA for U-Net") unet.set_use_sdpa(True) + """ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): # vae is not used currently, but it is here for future use @@ -2327,7 +2336,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)") + parser.add_argument( + "--sdpa", + action="store_true", + help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", + ) parser.add_argument( "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" ) @@ -3231,7 +3244,9 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: print(f"load StableDiffusion checkpoint: {name_or_path}") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2) + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( + args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 + ) else: # Diffusers model is loaded to CPU print(f"load Diffusers pretrained models: {name_or_path}") @@ -3281,7 +3296,10 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( - args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 + args, + weight_dtype, + accelerator.device if args.lowram else "cpu", + unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2, ) # work on low-ram device @@ -3595,7 +3613,17 @@ SCHEDLER_SCHEDULE = "scaled_linear" def sample_images( - accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None + accelerator, + args: argparse.Namespace, + epoch, + steps, + device, + vae, + tokenizer, + text_encoder, + unet, + prompt_replacement=None, + controlnet=None, ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した @@ -3690,7 +3718,7 @@ def sample_images( requires_safety_checker=False, clip_skip=args.clip_skip, ) - pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する + pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する pipeline.to(device) save_dir = args.output_dir + "/sample" @@ -3765,7 +3793,6 @@ def sample_images( controlnet_image = m.group(1) continue - except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex)