Use resize_image where resizing is required

This commit is contained in:
rockerBOO
2025-02-19 14:20:24 -05:00
parent ca1c129ffd
commit 7f2747176b
5 changed files with 113 additions and 71 deletions

View File

@@ -84,7 +84,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
@@ -1514,9 +1514,7 @@ class BaseDataset(torch.utils.data.Dataset):
nh = int(height * scale + 0.5)
nw = int(width * scale + 0.5)
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation)
logger.info(f"Interpolation: {interpolation}")
image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA)
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
face_cx = int(face_cx * scale + 0.5)
face_cy = int(face_cy * scale + 0.5)
height, width = nh, nw
@@ -2541,10 +2539,7 @@ class ControlNetDataset(BaseDataset):
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
interpolation = get_cv2_interpolation(self.resize_interpolation)
cond_img = cv2.resize(
cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA
) # INTER_AREAでやりたいのでcv2でリサイズ
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2558,7 +2553,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2961,12 +2956,7 @@ def trim_and_resize_if_required(
original_size = (image_width, image_height) # size before resize
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
if image_width > resized_size[0] and image_height > resized_size[1]:
interpolation = get_cv2_interpolation(resize_interpolation)
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
image_height, image_width = image.shape[0:2]
@@ -6566,28 +6556,3 @@ class LossRecorder:
return 0
return self.loss_total / losses
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
"""
Convert interpolation value to cv2 interpolation integer
"""
if interpolation is None:
return None
if interpolation == "lanczos":
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
return cv2.INTER_NEAREST
elif interpolation == "bilinear" or interpolation == "linear":
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
return cv2.INTER_CUBIC
elif interpolation == "area":
return cv2.INTER_AREA
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
Check if a interpolation function is supported
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"]

View File

@@ -16,7 +16,6 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
# endregion
@@ -377,7 +378,7 @@ def load_safetensors(
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
def pil_resize(image, size, interpolation):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
@@ -385,7 +386,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
resized_pil = pil_image.resize(size, resample=interpolation)
# Convert back to cv2 format
if has_alpha:
@@ -396,6 +397,100 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
"""
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS
Args:
image: numpy.ndarray
width: int Original image width
height: int Original image height
resized_width: int Resized image width
resized_height: int Resized image height
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
Returns:
image
"""
interpolation = get_cv2_interpolation(resize_interpolation)
resized_size = (resized_width, resized_height)
if width > resized_width and height > resized_width:
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
logger.debug(f"resize image using {resize_interpolation}")
else:
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_LANCZOS4) # INTER_AREAでやりたいのでcv2でリサイズ
logger.debug(f"resize image using {resize_interpolation}")
return image
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
"""
Convert interpolation value to cv2 interpolation integer
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
"""
if interpolation is None:
return None
if interpolation == "lanczos" or interpolation == "lanczos4":
# Lanczos interpolation over 8x8 neighborhood
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
return cv2.INTER_NEAREST_EXACT
elif interpolation == "bilinear" or interpolation == "linear":
# bilinear interpolation
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# bicubic interpolation
return cv2.INTER_CUBIC
elif interpolation == "area":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
elif interpolation == "box":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
else:
return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
"""
Convert interpolation value to PIL interpolation
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
"""
if interpolation is None:
return None
if interpolation == "lanczos":
return Image.Resampling.LANCZOS
elif interpolation == "nearest":
# Pick one nearest pixel from the input image. Ignore all other input pixels.
return Image.Resampling.NEAREST
elif interpolation == "bilinear" or interpolation == "linear":
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
return Image.Resampling.BILINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
return Image.Resampling.BICUBIC
elif interpolation == "area":
# Image.Resampling.BOX may be more appropriate if upscaling
# Area interpolation is related to cv2.INTER_AREA
# Produces a sharper image than Resampling.BILINEAR, doesnt have dislocations on local level like with Resampling.BOX.
return Image.Resampling.HAMMING
elif interpolation == "box":
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
return Image.Resampling.BOX
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
Check if a interpolation function is supported
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# endregion
# TODO make inf_utils.py