mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
make crop top/left same as stabilityai's prep
This commit is contained in:
@@ -121,7 +121,7 @@ class ImageInfo:
|
||||
self.latents_flipped: torch.Tensor = None
|
||||
self.latents_npz: str = None
|
||||
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
|
||||
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
|
||||
self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size
|
||||
self.cond_img_path: str = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||
# SDXL, optional
|
||||
@@ -256,6 +256,26 @@ class BucketManager:
|
||||
ar_error = (reso[0] / reso[1]) - aspect_ratio
|
||||
return reso, resized_size, ar_error
|
||||
|
||||
@staticmethod
|
||||
def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
|
||||
# Stability AIの前処理に合わせてcrop left/topを計算する。crop rightはflipのaugmentationのために求める
|
||||
# Calculate crop left/top according to the preprocessing of Stability AI. Crop right is calculated for flip augmentation.
|
||||
|
||||
bucket_ar = bucket_reso[0] / bucket_reso[1]
|
||||
image_ar = image_size[0] / image_size[1]
|
||||
if bucket_ar > image_ar:
|
||||
# bucketのほうが横長→縦を合わせる
|
||||
resized_width = bucket_reso[1] * image_ar
|
||||
resized_height = bucket_reso[1]
|
||||
else:
|
||||
resized_width = bucket_reso[0]
|
||||
resized_height = bucket_reso[0] / image_ar
|
||||
crop_left = (bucket_reso[0] - resized_width) // 2
|
||||
crop_top = (bucket_reso[1] - resized_height) // 2
|
||||
crop_right = crop_left + resized_width
|
||||
crop_bottom = crop_top + resized_height
|
||||
return crop_left, crop_top, crop_right, crop_bottom
|
||||
|
||||
|
||||
class BucketBatchIndex(NamedTuple):
|
||||
bucket_index: int
|
||||
@@ -1016,7 +1036,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
original_size = image_info.latents_original_size
|
||||
crop_left_top = image_info.latents_crop_left_top # calc values later if flipped
|
||||
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
|
||||
if not flipped:
|
||||
latents = image_info.latents
|
||||
else:
|
||||
@@ -1024,7 +1044,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_left_top, flipped_latents = load_latents_from_disk(image_info.latents_npz)
|
||||
latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
|
||||
if flipped:
|
||||
latents = flipped_latents
|
||||
del flipped_latents
|
||||
@@ -1037,7 +1057,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img, original_size, crop_left_top = trim_and_resize_if_required(
|
||||
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
|
||||
)
|
||||
else:
|
||||
@@ -1060,7 +1080,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
original_size = [im_w, im_h]
|
||||
crop_left_top = [0, 0]
|
||||
crop_ltrb = (0, 0, 0, 0)
|
||||
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug)
|
||||
@@ -1078,8 +1098,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
if flipped:
|
||||
crop_left_top = (original_size[0] - crop_left_top[0] - target_size[0], crop_left_top[1])
|
||||
if not flipped:
|
||||
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
|
||||
else:
|
||||
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
|
||||
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
|
||||
|
||||
original_sizes_hw.append((original_size[1], original_size[0]))
|
||||
crop_top_lefts.append((crop_left_top[1], crop_left_top[0]))
|
||||
@@ -1841,7 +1864,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
|
||||
return False
|
||||
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
|
||||
if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver?
|
||||
return False
|
||||
if npz["latents"].shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
@@ -1866,12 +1889,12 @@ def load_latents_from_disk(
|
||||
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
crop_left_top = npz["crop_left_top"].tolist()
|
||||
crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
return latents, original_size, crop_left_top, flipped_latents
|
||||
return latents, original_size, crop_ltrb, flipped_latents
|
||||
|
||||
|
||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top, flipped_latents_tensor=None):
|
||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
|
||||
kwargs = {}
|
||||
if flipped_latents_tensor is not None:
|
||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
@@ -1879,7 +1902,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top,
|
||||
npz_path,
|
||||
latents=latents_tensor.float().cpu().numpy(),
|
||||
original_size=np.array(original_size),
|
||||
crop_left_top=np.array(crop_left_top),
|
||||
crop_ltrb=np.array(crop_ltrb),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1918,7 +1941,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
)
|
||||
):
|
||||
print(
|
||||
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop left top: {crptl}, target size: {trgsz}, flipped: {flpdz}'
|
||||
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
|
||||
)
|
||||
|
||||
if show_input_ids:
|
||||
@@ -2063,35 +2086,37 @@ def load_image(image_path):
|
||||
return img
|
||||
|
||||
|
||||
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top)
|
||||
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
|
||||
def trim_and_resize_if_required(
|
||||
random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
|
||||
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
|
||||
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
|
||||
image_height, image_width = image.shape[0:2]
|
||||
original_size = (image_width, image_height) # size before resize
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
# リサイズする
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
|
||||
image_height, image_width = image.shape[0:2]
|
||||
original_size = (image_width, image_height)
|
||||
|
||||
crop_left_top = (0, 0)
|
||||
if image_width > reso[0]:
|
||||
trim_size = image_width - reso[0]
|
||||
p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
|
||||
# print("w", trim_size, p)
|
||||
image = image[:, p : p + reso[0]]
|
||||
crop_left_top = (p, 0)
|
||||
if image_height > reso[1]:
|
||||
trim_size = image_height - reso[1]
|
||||
p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
|
||||
# print("h", trim_size, p)
|
||||
image = image[p : p + reso[1]]
|
||||
crop_left_top = (crop_left_top[0], p)
|
||||
|
||||
# random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない
|
||||
# I have no idea how to reflect the cropped value in crop left/top in the case of random crop
|
||||
|
||||
crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size)
|
||||
|
||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
return image, original_size, crop_left_top
|
||||
return image, original_size, crop_ltrb
|
||||
|
||||
|
||||
def cache_batch_latents(
|
||||
@@ -2104,18 +2129,18 @@ def cache_batch_latents(
|
||||
flipped latents is also saved if flip_aug is True
|
||||
if cache_to_disk is False, set info.latents
|
||||
latents_flipped is also set if flip_aug is True
|
||||
latents_original_size and latents_crop_left_top are also set
|
||||
latents_original_size and latents_crop_ltrb are also set
|
||||
"""
|
||||
images = []
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_left_top = crop_left_top
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
@@ -2136,7 +2161,7 @@ def cache_batch_latents(
|
||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||
|
||||
if cache_to_disk:
|
||||
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top, flipped_latent)
|
||||
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
|
||||
else:
|
||||
info.latents = latent
|
||||
if flip_aug:
|
||||
@@ -3348,7 +3373,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
"""
|
||||
name = args.lr_scheduler
|
||||
num_warmup_steps: Optional[int] = args.lr_warmup_steps
|
||||
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
|
||||
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
|
||||
num_cycles = args.lr_scheduler_num_cycles
|
||||
power = args.lr_scheduler_power
|
||||
|
||||
|
||||
Reference in New Issue
Block a user