diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd6..6f5bdd36 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from PIL import Image from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -42,8 +42,10 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + if size > IMAGE_SIZE: + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) + else: + image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) image = image.astype(np.float32) return image diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..160e3b44 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -71,7 +71,7 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -2028,9 +2028,7 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = cv2.resize( - cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 - ) + cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2362,7 +2360,10 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if image_width > resized_size[0] and image_height > resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + else: + image = pil_resize(image, resized_size) image_height, image_width = image.shape[0:2] diff --git a/library/utils.py b/library/utils.py index 3037c055..a219f6cb 100644 --- a/library/utils.py +++ b/library/utils.py @@ -7,7 +7,9 @@ from typing import * from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - +import cv2 +from PIL import Image +import numpy as np def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -78,7 +80,17 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) +def pil_resize(image, size, interpolation=Image.LANCZOS): + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # use Pillow resize + resized_pil = pil_image.resize(size, interpolation) + + # return cv2 image + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + + return resized_cv2 # TODO make inf_utils.py diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index bbc643ed..d2a4d9cf 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ import os from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -172,7 +172,10 @@ def process(args): if scale != 1.0: w = int(w * scale + .5) h = int(h * scale + .5) - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + if scale < 1.0: + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) + else: + face_img = pil_resize(face_img, (w, h)) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index b8069fc1..0f9e00b1 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import shutil import math from PIL import Image import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Select interpolation method if interpolation == 'lanczos4': - cv2_interpolation = cv2.INTER_LANCZOS4 + pil_interpolation = Image.LANCZOS elif interpolation == 'cubic': - cv2_interpolation = cv2.INTER_CUBIC + pil_interpolation = Image.BICUBIC else: cv2_interpolation = cv2.INTER_AREA @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_width = int(img.shape[1] * math.sqrt(scale_factor)) # Resize image - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + if cv2_interpolation: + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) else: new_height, new_width = img.shape[0:2]