mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Use resize_image where resizing is required
This commit is contained in:
@@ -84,7 +84,7 @@ import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -1514,9 +1514,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
nh = int(height * scale + 0.5)
|
||||
nw = int(width * scale + 0.5)
|
||||
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
||||
interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation)
|
||||
logger.info(f"Interpolation: {interpolation}")
|
||||
image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA)
|
||||
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
|
||||
face_cx = int(face_cx * scale + 0.5)
|
||||
face_cy = int(face_cy * scale + 0.5)
|
||||
height, width = nh, nw
|
||||
@@ -2541,10 +2539,7 @@ class ControlNetDataset(BaseDataset):
|
||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||
|
||||
interpolation = get_cv2_interpolation(self.resize_interpolation)
|
||||
cond_img = cv2.resize(
|
||||
cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA
|
||||
) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
|
||||
# TODO support random crop
|
||||
# 現在サポートしているcropはrandomではなく中央のみ
|
||||
@@ -2558,7 +2553,7 @@ class ControlNetDataset(BaseDataset):
|
||||
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
# resize to target
|
||||
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
||||
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
|
||||
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
|
||||
if flipped:
|
||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||
@@ -2961,12 +2956,7 @@ def trim_and_resize_if_required(
|
||||
original_size = (image_width, image_height) # size before resize
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
# リサイズする
|
||||
if image_width > resized_size[0] and image_height > resized_size[1]:
|
||||
interpolation = get_cv2_interpolation(resize_interpolation)
|
||||
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
else:
|
||||
image = pil_resize(image, resized_size)
|
||||
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
|
||||
|
||||
image_height, image_width = image.shape[0:2]
|
||||
|
||||
@@ -6566,28 +6556,3 @@ class LossRecorder:
|
||||
return 0
|
||||
return self.loss_total / losses
|
||||
|
||||
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||||
"""
|
||||
Convert interpolation value to cv2 interpolation integer
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos":
|
||||
return cv2.INTER_LANCZOS4
|
||||
elif interpolation == "nearest":
|
||||
return cv2.INTER_NEAREST
|
||||
elif interpolation == "bilinear" or interpolation == "linear":
|
||||
return cv2.INTER_LINEAR
|
||||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||
return cv2.INTER_CUBIC
|
||||
elif interpolation == "area":
|
||||
return cv2.INTER_AREA
|
||||
else:
|
||||
return None
|
||||
|
||||
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
"""
|
||||
Check if a interpolation function is supported
|
||||
"""
|
||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"]
|
||||
|
||||
Reference in New Issue
Block a user