mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Use resize_image where resizing is required
This commit is contained in:
@@ -11,7 +11,7 @@ from PIL import Image
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
from library.utils import setup_logging, pil_resize
|
from library.utils import setup_logging, resize_image
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
@@ -42,10 +42,7 @@ def preprocess_image(image):
|
|||||||
pad_t = pad_y // 2
|
pad_t = pad_y // 2
|
||||||
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
||||||
|
|
||||||
if size > IMAGE_SIZE:
|
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
|
||||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
|
|
||||||
else:
|
|
||||||
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
|
|
||||||
|
|
||||||
image = image.astype(np.float32)
|
image = image.astype(np.float32)
|
||||||
return image
|
return image
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ import library.model_util as model_util
|
|||||||
import library.huggingface_util as huggingface_util
|
import library.huggingface_util as huggingface_util
|
||||||
import library.sai_model_spec as sai_model_spec
|
import library.sai_model_spec as sai_model_spec
|
||||||
import library.deepspeed_utils as deepspeed_utils
|
import library.deepspeed_utils as deepspeed_utils
|
||||||
from library.utils import setup_logging, pil_resize
|
from library.utils import setup_logging, resize_image
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
@@ -1514,9 +1514,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
nh = int(height * scale + 0.5)
|
nh = int(height * scale + 0.5)
|
||||||
nw = int(width * scale + 0.5)
|
nw = int(width * scale + 0.5)
|
||||||
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
||||||
interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation)
|
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
|
||||||
logger.info(f"Interpolation: {interpolation}")
|
|
||||||
image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA)
|
|
||||||
face_cx = int(face_cx * scale + 0.5)
|
face_cx = int(face_cx * scale + 0.5)
|
||||||
face_cy = int(face_cy * scale + 0.5)
|
face_cy = int(face_cy * scale + 0.5)
|
||||||
height, width = nh, nw
|
height, width = nh, nw
|
||||||
@@ -2541,10 +2539,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||||
|
|
||||||
interpolation = get_cv2_interpolation(self.resize_interpolation)
|
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||||
cond_img = cv2.resize(
|
|
||||||
cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA
|
|
||||||
) # INTER_AREAでやりたいのでcv2でリサイズ
|
|
||||||
|
|
||||||
# TODO support random crop
|
# TODO support random crop
|
||||||
# 現在サポートしているcropはrandomではなく中央のみ
|
# 現在サポートしているcropはrandomではなく中央のみ
|
||||||
@@ -2558,7 +2553,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||||
# resize to target
|
# resize to target
|
||||||
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
||||||
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
|
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||||
|
|
||||||
if flipped:
|
if flipped:
|
||||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||||
@@ -2961,12 +2956,7 @@ def trim_and_resize_if_required(
|
|||||||
original_size = (image_width, image_height) # size before resize
|
original_size = (image_width, image_height) # size before resize
|
||||||
|
|
||||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||||
# リサイズする
|
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
|
||||||
if image_width > resized_size[0] and image_height > resized_size[1]:
|
|
||||||
interpolation = get_cv2_interpolation(resize_interpolation)
|
|
||||||
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
|
||||||
else:
|
|
||||||
image = pil_resize(image, resized_size)
|
|
||||||
|
|
||||||
image_height, image_width = image.shape[0:2]
|
image_height, image_width = image.shape[0:2]
|
||||||
|
|
||||||
@@ -6566,28 +6556,3 @@ class LossRecorder:
|
|||||||
return 0
|
return 0
|
||||||
return self.loss_total / losses
|
return self.loss_total / losses
|
||||||
|
|
||||||
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
|
||||||
"""
|
|
||||||
Convert interpolation value to cv2 interpolation integer
|
|
||||||
"""
|
|
||||||
if interpolation is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if interpolation == "lanczos":
|
|
||||||
return cv2.INTER_LANCZOS4
|
|
||||||
elif interpolation == "nearest":
|
|
||||||
return cv2.INTER_NEAREST
|
|
||||||
elif interpolation == "bilinear" or interpolation == "linear":
|
|
||||||
return cv2.INTER_LINEAR
|
|
||||||
elif interpolation == "bicubic" or interpolation == "cubic":
|
|
||||||
return cv2.INTER_CUBIC
|
|
||||||
elif interpolation == "area":
|
|
||||||
return cv2.INTER_AREA
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a interpolation function is supported
|
|
||||||
"""
|
|
||||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"]
|
|
||||||
|
|||||||
101
library/utils.py
101
library/utils.py
@@ -16,7 +16,6 @@ from PIL import Image
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
|
||||||
def fire_in_thread(f, *args, **kwargs):
|
def fire_in_thread(f, *args, **kwargs):
|
||||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||||
|
|
||||||
@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.info(msg_init)
|
logger.info(msg_init)
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
@@ -377,7 +378,7 @@ def load_safetensors(
|
|||||||
# region Image utils
|
# region Image utils
|
||||||
|
|
||||||
|
|
||||||
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
def pil_resize(image, size, interpolation):
|
||||||
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
||||||
|
|
||||||
if has_alpha:
|
if has_alpha:
|
||||||
@@ -385,7 +386,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
|
|||||||
else:
|
else:
|
||||||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||||
|
|
||||||
resized_pil = pil_image.resize(size, interpolation)
|
resized_pil = pil_image.resize(size, resample=interpolation)
|
||||||
|
|
||||||
# Convert back to cv2 format
|
# Convert back to cv2 format
|
||||||
if has_alpha:
|
if has_alpha:
|
||||||
@@ -396,6 +397,100 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
|
|||||||
return resized_cv2
|
return resized_cv2
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: numpy.ndarray
|
||||||
|
width: int Original image width
|
||||||
|
height: int Original image height
|
||||||
|
resized_width: int Resized image width
|
||||||
|
resized_height: int Resized image height
|
||||||
|
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
image
|
||||||
|
"""
|
||||||
|
interpolation = get_cv2_interpolation(resize_interpolation)
|
||||||
|
resized_size = (resized_width, resized_height)
|
||||||
|
if width > resized_width and height > resized_width:
|
||||||
|
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||||
|
logger.debug(f"resize image using {resize_interpolation}")
|
||||||
|
else:
|
||||||
|
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_LANCZOS4) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||||
|
logger.debug(f"resize image using {resize_interpolation}")
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Convert interpolation value to cv2 interpolation integer
|
||||||
|
|
||||||
|
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
|
||||||
|
"""
|
||||||
|
if interpolation is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if interpolation == "lanczos" or interpolation == "lanczos4":
|
||||||
|
# Lanczos interpolation over 8x8 neighborhood
|
||||||
|
return cv2.INTER_LANCZOS4
|
||||||
|
elif interpolation == "nearest":
|
||||||
|
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
|
||||||
|
return cv2.INTER_NEAREST_EXACT
|
||||||
|
elif interpolation == "bilinear" or interpolation == "linear":
|
||||||
|
# bilinear interpolation
|
||||||
|
return cv2.INTER_LINEAR
|
||||||
|
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||||
|
# bicubic interpolation
|
||||||
|
return cv2.INTER_CUBIC
|
||||||
|
elif interpolation == "area":
|
||||||
|
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||||
|
return cv2.INTER_AREA
|
||||||
|
elif interpolation == "box":
|
||||||
|
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||||
|
return cv2.INTER_AREA
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
|
||||||
|
"""
|
||||||
|
Convert interpolation value to PIL interpolation
|
||||||
|
|
||||||
|
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
|
||||||
|
"""
|
||||||
|
if interpolation is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if interpolation == "lanczos":
|
||||||
|
return Image.Resampling.LANCZOS
|
||||||
|
elif interpolation == "nearest":
|
||||||
|
# Pick one nearest pixel from the input image. Ignore all other input pixels.
|
||||||
|
return Image.Resampling.NEAREST
|
||||||
|
elif interpolation == "bilinear" or interpolation == "linear":
|
||||||
|
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
|
||||||
|
return Image.Resampling.BILINEAR
|
||||||
|
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||||
|
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
|
||||||
|
return Image.Resampling.BICUBIC
|
||||||
|
elif interpolation == "area":
|
||||||
|
# Image.Resampling.BOX may be more appropriate if upscaling
|
||||||
|
# Area interpolation is related to cv2.INTER_AREA
|
||||||
|
# Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX.
|
||||||
|
return Image.Resampling.HAMMING
|
||||||
|
elif interpolation == "box":
|
||||||
|
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
|
||||||
|
return Image.Resampling.BOX
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a interpolation function is supported
|
||||||
|
"""
|
||||||
|
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# TODO make inf_utils.py
|
# TODO make inf_utils.py
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import os
|
|||||||
from anime_face_detector import create_detector
|
from anime_face_detector import create_detector
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from library.utils import setup_logging, pil_resize
|
from library.utils import setup_logging, resize_image
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -170,12 +170,9 @@ def process(args):
|
|||||||
scale = max(cur_crop_width / w, cur_crop_height / h)
|
scale = max(cur_crop_width / w, cur_crop_height / h)
|
||||||
|
|
||||||
if scale != 1.0:
|
if scale != 1.0:
|
||||||
w = int(w * scale + .5)
|
rw = int(w * scale + .5)
|
||||||
h = int(h * scale + .5)
|
rh = int(h * scale + .5)
|
||||||
if scale < 1.0:
|
face_img = resize_image(face_img, w, h, rw, rh)
|
||||||
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
|
|
||||||
else:
|
|
||||||
face_img = pil_resize(face_img, (w, h))
|
|
||||||
cx = int(cx * scale + .5)
|
cx = int(cx * scale + .5)
|
||||||
cy = int(cy * scale + .5)
|
cy = int(cy * scale + .5)
|
||||||
fw = int(fw * scale + .5)
|
fw = int(fw * scale + .5)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import shutil
|
|||||||
import math
|
import math
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from library.utils import setup_logging, pil_resize
|
from library.utils import setup_logging, resize_image
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||||||
if not os.path.exists(dst_img_folder):
|
if not os.path.exists(dst_img_folder):
|
||||||
os.makedirs(dst_img_folder)
|
os.makedirs(dst_img_folder)
|
||||||
|
|
||||||
# Select interpolation method
|
|
||||||
if interpolation == 'lanczos4':
|
|
||||||
pil_interpolation = Image.LANCZOS
|
|
||||||
elif interpolation == 'cubic':
|
|
||||||
pil_interpolation = Image.BICUBIC
|
|
||||||
else:
|
|
||||||
cv2_interpolation = cv2.INTER_AREA
|
|
||||||
|
|
||||||
# Iterate through all files in src_img_folder
|
# Iterate through all files in src_img_folder
|
||||||
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
|
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
|
||||||
for filename in os.listdir(src_img_folder):
|
for filename in os.listdir(src_img_folder):
|
||||||
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||||||
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||||
|
|
||||||
# Resize image
|
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
|
||||||
if cv2_interpolation:
|
|
||||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
|
||||||
else:
|
|
||||||
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
|
|
||||||
else:
|
else:
|
||||||
new_height, new_width = img.shape[0:2]
|
new_height, new_width = img.shape[0:2]
|
||||||
|
|
||||||
@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
|
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
|
||||||
parser.add_argument('--divisible_by', type=int,
|
parser.add_argument('--divisible_by', type=int,
|
||||||
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
|
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
|
||||||
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
|
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
|
||||||
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
|
default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補間方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
|
||||||
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
|
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
|
||||||
parser.add_argument('--copy_associated_files', action='store_true',
|
parser.add_argument('--copy_associated_files', action='store_true',
|
||||||
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
||||||
|
|||||||
Reference in New Issue
Block a user