Use resize_image where resizing is required

This commit is contained in:
rockerBOO
2025-02-19 14:20:24 -05:00
parent ca1c129ffd
commit 7f2747176b
5 changed files with 113 additions and 71 deletions

View File

@@ -84,7 +84,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
@@ -1514,9 +1514,7 @@ class BaseDataset(torch.utils.data.Dataset):
nh = int(height * scale + 0.5)
nw = int(width * scale + 0.5)
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation)
logger.info(f"Interpolation: {interpolation}")
image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA)
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
face_cx = int(face_cx * scale + 0.5)
face_cy = int(face_cy * scale + 0.5)
height, width = nh, nw
@@ -2541,10 +2539,7 @@ class ControlNetDataset(BaseDataset):
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
interpolation = get_cv2_interpolation(self.resize_interpolation)
cond_img = cv2.resize(
cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA
) # INTER_AREAでやりたいのでcv2でリサイズ
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2558,7 +2553,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2961,12 +2956,7 @@ def trim_and_resize_if_required(
original_size = (image_width, image_height) # size before resize
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
if image_width > resized_size[0] and image_height > resized_size[1]:
interpolation = get_cv2_interpolation(resize_interpolation)
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
image_height, image_width = image.shape[0:2]
@@ -6566,28 +6556,3 @@ class LossRecorder:
return 0
return self.loss_total / losses
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
    """
    Convert an interpolation name to the matching cv2 interpolation constant.

    Args:
        interpolation: One of "lanczos", "nearest", "bilinear" / "linear",
            "bicubic" / "cubic" or "area". May be None. Names are compared
            exactly (case-sensitive).

    Returns:
        The corresponding cv2 interpolation constant, or None when
        ``interpolation`` is None or is not one of the names listed above.
    """
    # No name given: return None rather than picking a default here.
    if interpolation is None:
        return None

    if interpolation == "lanczos":
        return cv2.INTER_LANCZOS4
    if interpolation == "nearest":
        return cv2.INTER_NEAREST
    # "bilinear"/"linear" and "bicubic"/"cubic" are accepted as aliases.
    if interpolation in ("bilinear", "linear"):
        return cv2.INTER_LINEAR
    if interpolation in ("bicubic", "cubic"):
        return cv2.INTER_CUBIC
    if interpolation == "area":
        return cv2.INTER_AREA

    # Unrecognized names are reported as None rather than raising.
    return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
    """
    Check whether an interpolation name is supported.

    Args:
        interpolation_str: Interpolation name to check. Matching is exact
            (case-sensitive).

    Returns:
        True if the name is one of "lanczos", "nearest", "bilinear",
        "linear", "bicubic", "cubic" or "area"; False otherwise.
    """
    supported_names = (
        "lanczos",
        "nearest",
        "bilinear",
        "linear",
        "bicubic",
        "cubic",
        "area",
    )
    return interpolation_str in supported_names