diff --git a/library/train_util.py b/library/train_util.py
index db596bc8..c1a5985d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1701,24 +1701,38 @@ class BaseDataset(torch.utils.data.Dataset):
                 images.append(image)
                 latents_list.append(None)
                 alpha_mask_list.append(alpha_mask)
+
+                target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
+
+                if not flipped:
+                    crop_left_top = (crop_ltrb[0], crop_ltrb[1])
+                else:
+                    # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
+                    crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
+
+                original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
+                crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
+                target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
+                flippeds.append(flipped)
             else:
                 image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped)
                 images.append(image)
                 latents_list.append(None)
                 alpha_mask_list.append(alpha_mask)
-            target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
+                target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
 
-            if not flipped:
-                crop_left_top = (crop_ltrb[0], crop_ltrb[1])
-            else:
-                # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
-                crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
+                if not flipped:
+                    crop_left_top = (crop_ltrb[0], crop_ltrb[1])
+                else:
+                    # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
+                    crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
+
+                original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
+                crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
+                target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
+                flippeds.append(flipped)
 
-            original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
-            crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
-            target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
-            flippeds.append(flipped)
 
             # captionとtext encoder outputを処理する
             caption = image_info.caption  # default
 