Use resize_image where resizing is required

This commit is contained in:
rockerBOO
2025-02-19 14:20:24 -05:00
parent ca1c129ffd
commit 7f2747176b
5 changed files with 113 additions and 71 deletions

View File

@@ -11,7 +11,7 @@ from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import library.train_util as train_util import library.train_util as train_util
from library.utils import setup_logging, pil_resize from library.utils import setup_logging, resize_image
setup_logging() setup_logging()
import logging import logging
@@ -42,10 +42,7 @@ def preprocess_image(image):
pad_t = pad_y // 2 pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
if size > IMAGE_SIZE: image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = image.astype(np.float32) image = image.astype(np.float32)
return image return image

View File

@@ -84,7 +84,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize from library.utils import setup_logging, resize_image
setup_logging() setup_logging()
import logging import logging
@@ -1514,9 +1514,7 @@ class BaseDataset(torch.utils.data.Dataset):
nh = int(height * scale + 0.5) nh = int(height * scale + 0.5)
nw = int(width * scale + 0.5) nw = int(width * scale + 0.5)
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation) image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
logger.info(f"Interpolation: {interpolation}")
image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA)
face_cx = int(face_cx * scale + 0.5) face_cx = int(face_cx * scale + 0.5)
face_cy = int(face_cy * scale + 0.5) face_cy = int(face_cy * scale + 0.5)
height, width = nh, nw height, width = nh, nw
@@ -2541,10 +2539,7 @@ class ControlNetDataset(BaseDataset):
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
interpolation = get_cv2_interpolation(self.resize_interpolation) cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
cond_img = cv2.resize(
cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA
) # INTER_AREAでやりたいのでcv2でリサイズ
# TODO support random crop # TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ # 現在サポートしているcropはrandomではなく中央のみ
@@ -2558,7 +2553,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target # resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
if flipped: if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2961,12 +2956,7 @@ def trim_and_resize_if_required(
original_size = (image_width, image_height) # size before resize original_size = (image_width, image_height) # size before resize
if image_width != resized_size[0] or image_height != resized_size[1]: if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
if image_width > resized_size[0] and image_height > resized_size[1]:
interpolation = get_cv2_interpolation(resize_interpolation)
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)
image_height, image_width = image.shape[0:2] image_height, image_width = image.shape[0:2]
@@ -6566,28 +6556,3 @@ class LossRecorder:
return 0 return 0
return self.loss_total / losses return self.loss_total / losses
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
    """
    Translate an interpolation name into the matching cv2 interpolation flag.

    Args:
        interpolation: One of "lanczos", "nearest", "bilinear"/"linear",
            "bicubic"/"cubic", "area", or None.

    Returns:
        The cv2.INTER_* flag for the given name, or None when the name is
        None or not recognised (callers then pick their own default).
    """
    if interpolation is None:
        return None

    # Guard clauses, one per supported family of aliases.
    if interpolation == "lanczos":
        return cv2.INTER_LANCZOS4
    if interpolation == "nearest":
        return cv2.INTER_NEAREST
    if interpolation in ("bilinear", "linear"):
        return cv2.INTER_LINEAR
    if interpolation in ("bicubic", "cubic"):
        return cv2.INTER_CUBIC
    if interpolation == "area":
        return cv2.INTER_AREA

    # Unknown name: signal "no preference" rather than raising.
    return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
    """
    Report whether the given interpolation name is one of the supported ones.

    Args:
        interpolation_str: Name of the interpolation method to check.

    Returns:
        True if the name is supported, False otherwise.
    """
    supported_names = ("lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area")
    return interpolation_str in supported_names

View File

@@ -16,7 +16,6 @@ from PIL import Image
import numpy as np import numpy as np
from safetensors.torch import load_file from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs): def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start() threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.info(msg_init) logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
# endregion # endregion
@@ -377,7 +378,7 @@ def load_safetensors(
# region Image utils # region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS): def pil_resize(image, size, interpolation):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha: if has_alpha:
@@ -385,7 +386,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
else: else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation) resized_pil = pil_image.resize(size, resample=interpolation)
# Convert back to cv2 format # Convert back to cv2 format
if has_alpha: if has_alpha:
@@ -396,6 +397,100 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2 return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
    """
    Resize an image with cv2, using the requested interpolation method.

    When no (or an unrecognised) interpolation name is given, the default is
    AREA if the image is being downscaled in both dimensions, else LANCZOS4.

    Args:
        image: numpy.ndarray Image to resize
        width: int Original image width
        height: int Original image height
        resized_width: int Resized image width
        resized_height: int Resized image height
        resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"

    Returns:
        The resized image
    """
    interpolation = get_cv2_interpolation(resize_interpolation)

    if interpolation is None:
        # Compare each original dimension with its own target dimension
        # (height against resized_height, not resized_width).
        is_downscale = width > resized_width and height > resized_height
        interpolation = cv2.INTER_AREA if is_downscale else cv2.INTER_LANCZOS4

    # cv2.resize takes dsize as (width, height). The int() cast mirrors the
    # explicit cast one of the replaced call sites applied to its target size.
    resized_size = (int(resized_width), int(resized_height))

    image = cv2.resize(image, resized_size, interpolation=interpolation)
    logger.debug(f"resize image using {resize_interpolation}")

    return image
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
    """
    Translate an interpolation name into the matching cv2 interpolation flag.

    https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121

    Args:
        interpolation: One of "lanczos"/"lanczos4", "nearest",
            "bilinear"/"linear", "bicubic"/"cubic", "area", "box", or None.

    Returns:
        The cv2.INTER_* flag for the given name, or None when the name is
        None or not recognised (callers then pick their own default).
    """
    if interpolation is None:
        return None

    if interpolation in ("lanczos", "lanczos4"):
        # Lanczos interpolation over 8x8 neighborhood
        return cv2.INTER_LANCZOS4

    if interpolation == "nearest":
        # Bit exact nearest neighbor interpolation. Produces the same results
        # as the nearest neighbor method in PIL, scikit-image or Matlab.
        return cv2.INTER_NEAREST_EXACT

    if interpolation in ("bilinear", "linear"):
        # bilinear interpolation
        return cv2.INTER_LINEAR

    if interpolation in ("bicubic", "cubic"):
        # bicubic interpolation
        return cv2.INTER_CUBIC

    if interpolation in ("area", "box"):
        # Resampling using pixel area relation. It may be a preferred method
        # for image decimation, as it gives moire'-free results. But when the
        # image is zoomed, it is similar to the INTER_NEAREST method.
        # "box" maps to the same flag as "area".
        return cv2.INTER_AREA

    # Unknown name: signal "no preference" rather than raising.
    return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
    """
    Convert interpolation value to PIL interpolation

    https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters

    Args:
        interpolation: One of "lanczos"/"lanczos4", "nearest",
            "bilinear"/"linear", "bicubic"/"cubic", "area", "box", or None.

    Returns:
        The Image.Resampling filter for the given name, or None when the name
        is None or not recognised.
    """
    if interpolation is None:
        return None

    if interpolation == "lanczos" or interpolation == "lanczos4":
        # "lanczos4" is accepted as an alias so the same names work here as
        # in get_cv2_interpolation.
        return Image.Resampling.LANCZOS
    elif interpolation == "nearest":
        # Pick one nearest pixel from the input image. Ignore all other input pixels.
        return Image.Resampling.NEAREST
    elif interpolation == "bilinear" or interpolation == "linear":
        # For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value.
        # For other transformations linear interpolation over a 2x2 environment in the input image is used.
        return Image.Resampling.BILINEAR
    elif interpolation == "bicubic" or interpolation == "cubic":
        # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value.
        # For other transformations cubic interpolation over a 4x4 environment in the input image is used.
        return Image.Resampling.BICUBIC
    elif interpolation == "area":
        # Image.Resampling.BOX may be more appropriate if upscaling
        # Area interpolation is related to cv2.INTER_AREA
        # HAMMING produces a sharper image than Resampling.BILINEAR and doesn't have dislocations on local level like with Resampling.BOX.
        return Image.Resampling.HAMMING
    elif interpolation == "box":
        # Each pixel of source image contributes to one pixel of the destination image with identical weights.
        # For upscaling is equivalent of Resampling.NEAREST.
        return Image.Resampling.BOX
    else:
        # Unknown name: signal "no preference" rather than raising.
        return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
    """
    Check if a interpolation function is supported

    Args:
        interpolation_str: Name of the interpolation method to check.

    Returns:
        True if the name is supported, False otherwise.
    """
    # "lanczos4" is included because get_cv2_interpolation accepts it as an
    # alias of "lanczos"; validation should not reject a name the resizer handles.
    supported_names = ("lanczos", "lanczos4", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box")
    return interpolation_str in supported_names
# endregion # endregion
# TODO make inf_utils.py # TODO make inf_utils.py

View File

@@ -15,7 +15,7 @@ import os
from anime_face_detector import create_detector from anime_face_detector import create_detector
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
from library.utils import setup_logging, pil_resize from library.utils import setup_logging, resize_image
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -170,12 +170,9 @@ def process(args):
scale = max(cur_crop_width / w, cur_crop_height / h) scale = max(cur_crop_width / w, cur_crop_height / h)
if scale != 1.0: if scale != 1.0:
w = int(w * scale + .5) rw = int(w * scale + .5)
h = int(h * scale + .5) rh = int(h * scale + .5)
if scale < 1.0: face_img = resize_image(face_img, w, h, rw, rh)
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
cx = int(cx * scale + .5) cx = int(cx * scale + .5)
cy = int(cy * scale + .5) cy = int(cy * scale + .5)
fw = int(fw * scale + .5) fw = int(fw * scale + .5)

View File

@@ -6,7 +6,7 @@ import shutil
import math import math
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from library.utils import setup_logging, pil_resize from library.utils import setup_logging, resize_image
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
if not os.path.exists(dst_img_folder): if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder) os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder # Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder): for filename in os.listdir(src_img_folder):
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_height = int(img.shape[0] * math.sqrt(scale_factor)) new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor)) new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
else: else:
new_height, new_width = img.shape[0:2] new_height, new_width = img.shape[0:2]
@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int, parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
default='area', help='Interpolation method for resizing / サイズの補方法') default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true', parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')