diff --git a/library/train_util.py b/library/train_util.py index a4021c4f..785dc0f9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -121,7 +121,7 @@ class ImageInfo: self.latents_flipped: torch.Tensor = None self.latents_npz: str = None self.latents_original_size: Tuple[int, int] = None # original image size, not latents size - self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top + self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size self.cond_img_path: str = None self.image: Optional[Image.Image] = None # optional, original PIL Image # SDXL, optional @@ -256,6 +256,26 @@ class BucketManager: ar_error = (reso[0] / reso[1]) - aspect_ratio return reso, resized_size, ar_error + @staticmethod + def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]): + # Stability AIの前処理に合わせてcrop left/topを計算する。crop rightはflipのaugmentationのために求める + # Calculate crop left/top according to the preprocessing of Stability AI. Crop right is calculated for flip augmentation. + + bucket_ar = bucket_reso[0] / bucket_reso[1] + image_ar = image_size[0] / image_size[1] + if bucket_ar > image_ar: + # bucketのほうが横長→縦を合わせる + resized_width = bucket_reso[1] * image_ar + resized_height = bucket_reso[1] + else: + resized_width = bucket_reso[0] + resized_height = bucket_reso[0] / image_ar + crop_left = (bucket_reso[0] - resized_width) // 2 + crop_top = (bucket_reso[1] - resized_height) // 2 + crop_right = crop_left + resized_width + crop_bottom = crop_top + resized_height + return crop_left, crop_top, crop_right, crop_bottom + class BucketBatchIndex(NamedTuple): bucket_index: int @@ -1016,7 +1036,7 @@ class BaseDataset(torch.utils.data.Dataset): # image/latentsを処理する if image_info.latents is not None: # cache_latents=Trueの場合 original_size = image_info.latents_original_size - crop_left_top = image_info.latents_crop_left_top # calc values later if flipped + crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped if not flipped: latents = image_info.latents else: @@ -1024,7 +1044,7 @@ class BaseDataset(torch.utils.data.Dataset): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_left_top, flipped_latents = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) if flipped: latents = flipped_latents del flipped_latents @@ -1037,7 +1057,7 @@ class BaseDataset(torch.utils.data.Dataset): im_h, im_w = img.shape[0:2] if self.enable_bucket: - img, original_size, crop_left_top = trim_and_resize_if_required( + img, original_size, crop_ltrb = trim_and_resize_if_required( subset.random_crop, img, image_info.bucket_reso, image_info.resized_size ) else: @@ -1060,7 +1080,7 @@ class BaseDataset(torch.utils.data.Dataset): ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" original_size = [im_w, im_h] - crop_left_top = [0, 0] + crop_ltrb = (0, 0, 0, 0) # augmentation aug = self.aug_helper.get_augmentor(subset.color_aug) @@ -1078,8 +1098,11 @@ class BaseDataset(torch.utils.data.Dataset): target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) - if flipped: - crop_left_top = (original_size[0] - crop_left_top[0] - target_size[0], crop_left_top[1]) + if not flipped: + crop_left_top = (crop_ltrb[0], crop_ltrb[1]) + else: + # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image + crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) original_sizes_hw.append((original_size[1], original_size[0])) crop_top_lefts.append((crop_left_top[1], crop_left_top[0])) @@ -1841,7 +1864,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): return False npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver? + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? return False if npz["latents"].shape[1:3] != expected_latents_size: return False @@ -1866,12 +1889,12 @@ def load_latents_from_disk( latents = npz["latents"] original_size = npz["original_size"].tolist() - crop_left_top = npz["crop_left_top"].tolist() + crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - return latents, original_size, crop_left_top, flipped_latents + return latents, original_size, crop_ltrb, flipped_latents -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top, flipped_latents_tensor=None): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() @@ -1879,7 +1902,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top, npz_path, latents=latents_tensor.float().cpu().numpy(), original_size=np.array(original_size), - crop_left_top=np.array(crop_left_top), + crop_ltrb=np.array(crop_ltrb), **kwargs, ) @@ -1918,7 +1941,7 @@ def debug_dataset(train_dataset, show_input_ids=False): ) ): print( - f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop left top: {crptl}, target size: {trgsz}, flipped: {flpdz}' + f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if show_input_ids: @@ -2063,35 +2086,37 @@ def load_image(image_path): return img -# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top) +# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] -) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]: +) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] + original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ image_height, image_width = image.shape[0:2] - original_size = (image_width, image_height) - crop_left_top = (0, 0) if image_width > reso[0]: trim_size = image_width - reso[0] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) # print("w", trim_size, p) image = image[:, p : p + reso[0]] - crop_left_top = (p, 0) if image_height > reso[1]: trim_size = image_height - reso[1] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) # print("h", trim_size, p) image = image[p : p + reso[1]] - crop_left_top = (crop_left_top[0], p) + + # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない + # I have no idea how to reflect the cropped value in crop left/top in the case of random crop + + crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size) assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image, original_size, crop_left_top + return image, original_size, crop_ltrb def cache_batch_latents( @@ -2104,18 +2129,18 @@ def cache_batch_latents( flipped latents is also saved if flip_aug is True if cache_to_disk is False, set info.latents latents_flipped is also set if flip_aug is True - latents_original_size and latents_crop_left_top are also set + latents_original_size and latents_crop_ltrb are also set """ images = [] for info in image_infos: image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) image = IMAGE_TRANSFORMS(image) images.append(image) info.latents_original_size = original_size - info.latents_crop_left_top = crop_left_top + info.latents_crop_ltrb = crop_ltrb img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) @@ -2136,7 +2161,7 @@ def cache_batch_latents( raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top, flipped_latent) + save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) else: info.latents = latent if flip_aug: @@ -3348,7 +3373,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ name = args.lr_scheduler num_warmup_steps: Optional[int] = args.lr_warmup_steps - num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps + num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power