make crop top/left same as stabilityai's prep

This commit is contained in:
Kohya S
2023-07-18 21:39:36 +09:00
parent 3d66a234b0
commit 0ec7166098

View File

@@ -121,7 +121,7 @@ class ImageInfo:
self.latents_flipped: torch.Tensor = None
self.latents_npz: str = None
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size
self.cond_img_path: str = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
# SDXL, optional
@@ -256,6 +256,26 @@ class BucketManager:
ar_error = (reso[0] / reso[1]) - aspect_ratio
return reso, resized_size, ar_error
@staticmethod
def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
# Stability AIの前処理に合わせてcrop left/topを計算する。crop rightはflipのaugmentationのために求める
# Calculate crop left/top according to the preprocessing of Stability AI. Crop right is calculated for flip augmentation.
bucket_ar = bucket_reso[0] / bucket_reso[1]
image_ar = image_size[0] / image_size[1]
if bucket_ar > image_ar:
# bucketのほうが横長→縦を合わせる
resized_width = bucket_reso[1] * image_ar
resized_height = bucket_reso[1]
else:
resized_width = bucket_reso[0]
resized_height = bucket_reso[0] / image_ar
crop_left = (bucket_reso[0] - resized_width) // 2
crop_top = (bucket_reso[1] - resized_height) // 2
crop_right = crop_left + resized_width
crop_bottom = crop_top + resized_height
return crop_left, crop_top, crop_right, crop_bottom
class BucketBatchIndex(NamedTuple):
bucket_index: int
@@ -1016,7 +1036,7 @@ class BaseDataset(torch.utils.data.Dataset):
# image/latentsを処理する
if image_info.latents is not None: # cache_latents=Trueの場合
original_size = image_info.latents_original_size
crop_left_top = image_info.latents_crop_left_top # calc values later if flipped
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
if not flipped:
latents = image_info.latents
else:
@@ -1024,7 +1044,7 @@ class BaseDataset(torch.utils.data.Dataset):
image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_left_top, flipped_latents = load_latents_from_disk(image_info.latents_npz)
latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
if flipped:
latents = flipped_latents
del flipped_latents
@@ -1037,7 +1057,7 @@ class BaseDataset(torch.utils.data.Dataset):
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img, original_size, crop_left_top = trim_and_resize_if_required(
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
)
else:
@@ -1060,7 +1080,7 @@ class BaseDataset(torch.utils.data.Dataset):
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
original_size = [im_w, im_h]
crop_left_top = [0, 0]
crop_ltrb = (0, 0, 0, 0)
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug)
@@ -1078,8 +1098,11 @@ class BaseDataset(torch.utils.data.Dataset):
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
if flipped:
crop_left_top = (original_size[0] - crop_left_top[0] - target_size[0], crop_left_top[1])
if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
original_sizes_hw.append((original_size[1], original_size[0]))
crop_top_lefts.append((crop_left_top[1], crop_left_top[0]))
@@ -1841,7 +1864,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
return False
npz = np.load(npz_path)
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver?
return False
if npz["latents"].shape[1:3] != expected_latents_size:
return False
@@ -1866,12 +1889,12 @@ def load_latents_from_disk(
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_left_top = npz["crop_left_top"].tolist()
crop_ltrb = npz["crop_ltrb"].tolist()
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
return latents, original_size, crop_left_top, flipped_latents
return latents, original_size, crop_ltrb, flipped_latents
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top, flipped_latents_tensor=None):
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
kwargs = {}
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
@@ -1879,7 +1902,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top,
npz_path,
latents=latents_tensor.float().cpu().numpy(),
original_size=np.array(original_size),
crop_left_top=np.array(crop_left_top),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
@@ -1918,7 +1941,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
)
):
print(
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop left top: {crptl}, target size: {trgsz}, flipped: {flpdz}'
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
)
if show_input_ids:
@@ -2063,35 +2086,37 @@ def load_image(image_path):
return img
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top)
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height) # size before resize
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height)
crop_left_top = (0, 0)
if image_width > reso[0]:
trim_size = image_width - reso[0]
p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
# print("w", trim_size, p)
image = image[:, p : p + reso[0]]
crop_left_top = (p, 0)
if image_height > reso[1]:
trim_size = image_height - reso[1]
p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
# print("h", trim_size, p)
image = image[p : p + reso[1]]
crop_left_top = (crop_left_top[0], p)
# random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない
# I have no idea how to reflect the cropped value in crop left/top in the case of random crop
crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size)
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
return image, original_size, crop_left_top
return image, original_size, crop_ltrb
def cache_batch_latents(
@@ -2104,18 +2129,18 @@ def cache_batch_latents(
flipped latents is also saved if flip_aug is True
if cache_to_disk is False, set info.latents
latents_flipped is also set if flip_aug is True
latents_original_size and latents_crop_left_top are also set
latents_original_size and latents_crop_ltrb are also set
"""
images = []
for info in image_infos:
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image = IMAGE_TRANSFORMS(image)
images.append(image)
info.latents_original_size = original_size
info.latents_crop_left_top = crop_left_top
info.latents_crop_ltrb = crop_ltrb
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
@@ -2136,7 +2161,7 @@ def cache_batch_latents(
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top, flipped_latent)
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
else:
info.latents = latent
if flip_aug:
@@ -3348,7 +3373,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
"""
name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power