mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add resize interpolation configuration
This commit is contained in:
@@ -75,6 +75,7 @@ class BaseSubsetParams:
|
|||||||
custom_attributes: Optional[Dict[str, Any]] = None
|
custom_attributes: Optional[Dict[str, Any]] = None
|
||||||
validation_seed: int = 0
|
validation_seed: int = 0
|
||||||
validation_split: float = 0.0
|
validation_split: float = 0.0
|
||||||
|
resize_interpolation: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -106,7 +107,7 @@ class BaseDatasetParams:
|
|||||||
debug_dataset: bool = False
|
debug_dataset: bool = False
|
||||||
validation_seed: Optional[int] = None
|
validation_seed: Optional[int] = None
|
||||||
validation_split: float = 0.0
|
validation_split: float = 0.0
|
||||||
|
resize_interpolation: Optional[str] = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||||
@@ -196,6 +197,7 @@ class ConfigSanitizer:
|
|||||||
"caption_prefix": str,
|
"caption_prefix": str,
|
||||||
"caption_suffix": str,
|
"caption_suffix": str,
|
||||||
"custom_attributes": dict,
|
"custom_attributes": dict,
|
||||||
|
"resize_interpolation": str,
|
||||||
}
|
}
|
||||||
# DO means DropOut
|
# DO means DropOut
|
||||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||||
@@ -241,6 +243,7 @@ class ConfigSanitizer:
|
|||||||
"validation_split": float,
|
"validation_split": float,
|
||||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||||
"network_multiplier": float,
|
"network_multiplier": float,
|
||||||
|
"resize_interpolation": str,
|
||||||
}
|
}
|
||||||
|
|
||||||
# options handled by argparse but not handled by user config
|
# options handled by argparse but not handled by user config
|
||||||
@@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
|||||||
[{dataset_type} {i}]
|
[{dataset_type} {i}]
|
||||||
batch_size: {dataset.batch_size}
|
batch_size: {dataset.batch_size}
|
||||||
resolution: {(dataset.width, dataset.height)}
|
resolution: {(dataset.width, dataset.height)}
|
||||||
|
resize_interpolation: {dataset.resize_interpolation}
|
||||||
enable_bucket: {dataset.enable_bucket}
|
enable_bucket: {dataset.enable_bucket}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
|||||||
token_warmup_min: {subset.token_warmup_min},
|
token_warmup_min: {subset.token_warmup_min},
|
||||||
token_warmup_step: {subset.token_warmup_step},
|
token_warmup_step: {subset.token_warmup_step},
|
||||||
alpha_mask: {subset.alpha_mask}
|
alpha_mask: {subset.alpha_mask}
|
||||||
|
resize_interpolation: {subset.resize_interpolation}
|
||||||
custom_attributes: {subset.custom_attributes}
|
custom_attributes: {subset.custom_attributes}
|
||||||
"""), " ")
|
"""), " ")
|
||||||
|
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ class ImageInfo:
|
|||||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||||
|
self.resize_interpolation: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class BucketManager:
|
class BucketManager:
|
||||||
@@ -434,6 +435,7 @@ class BaseSubset:
|
|||||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||||
validation_seed: Optional[int] = None,
|
validation_seed: Optional[int] = None,
|
||||||
validation_split: Optional[float] = 0.0,
|
validation_split: Optional[float] = 0.0,
|
||||||
|
resize_interpolation: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.image_dir = image_dir
|
self.image_dir = image_dir
|
||||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||||
@@ -464,6 +466,8 @@ class BaseSubset:
|
|||||||
self.validation_seed = validation_seed
|
self.validation_seed = validation_seed
|
||||||
self.validation_split = validation_split
|
self.validation_split = validation_split
|
||||||
|
|
||||||
|
self.resize_interpolation = resize_interpolation
|
||||||
|
|
||||||
|
|
||||||
class DreamBoothSubset(BaseSubset):
|
class DreamBoothSubset(BaseSubset):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -495,6 +499,7 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||||
validation_seed: Optional[int] = None,
|
validation_seed: Optional[int] = None,
|
||||||
validation_split: Optional[float] = 0.0,
|
validation_split: Optional[float] = 0.0,
|
||||||
|
resize_interpolation: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||||
|
|
||||||
@@ -522,6 +527,7 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
custom_attributes=custom_attributes,
|
custom_attributes=custom_attributes,
|
||||||
validation_seed=validation_seed,
|
validation_seed=validation_seed,
|
||||||
validation_split=validation_split,
|
validation_split=validation_split,
|
||||||
|
resize_interpolation=resize_interpolation,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.is_reg = is_reg
|
self.is_reg = is_reg
|
||||||
@@ -564,6 +570,7 @@ class FineTuningSubset(BaseSubset):
|
|||||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||||
validation_seed: Optional[int] = None,
|
validation_seed: Optional[int] = None,
|
||||||
validation_split: Optional[float] = 0.0,
|
validation_split: Optional[float] = 0.0,
|
||||||
|
resize_interpolation: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||||
|
|
||||||
@@ -591,6 +598,7 @@ class FineTuningSubset(BaseSubset):
|
|||||||
custom_attributes=custom_attributes,
|
custom_attributes=custom_attributes,
|
||||||
validation_seed=validation_seed,
|
validation_seed=validation_seed,
|
||||||
validation_split=validation_split,
|
validation_split=validation_split,
|
||||||
|
resize_interpolation=resize_interpolation,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.metadata_file = metadata_file
|
self.metadata_file = metadata_file
|
||||||
@@ -629,6 +637,7 @@ class ControlNetSubset(BaseSubset):
|
|||||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||||
validation_seed: Optional[int] = None,
|
validation_seed: Optional[int] = None,
|
||||||
validation_split: Optional[float] = 0.0,
|
validation_split: Optional[float] = 0.0,
|
||||||
|
resize_interpolation: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||||
|
|
||||||
@@ -656,6 +665,7 @@ class ControlNetSubset(BaseSubset):
|
|||||||
custom_attributes=custom_attributes,
|
custom_attributes=custom_attributes,
|
||||||
validation_seed=validation_seed,
|
validation_seed=validation_seed,
|
||||||
validation_split=validation_split,
|
validation_split=validation_split,
|
||||||
|
resize_interpolation=resize_interpolation,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conditioning_data_dir = conditioning_data_dir
|
self.conditioning_data_dir = conditioning_data_dir
|
||||||
@@ -676,6 +686,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
resolution: Optional[Tuple[int, int]],
|
resolution: Optional[Tuple[int, int]],
|
||||||
network_multiplier: float,
|
network_multiplier: float,
|
||||||
debug_dataset: bool,
|
debug_dataset: bool,
|
||||||
|
resize_interpolation: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -710,6 +721,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.image_transforms = IMAGE_TRANSFORMS
|
self.image_transforms = IMAGE_TRANSFORMS
|
||||||
|
|
||||||
|
if resize_interpolation is not None:
|
||||||
|
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
|
||||||
|
self.resize_interpolation = resize_interpolation
|
||||||
|
|
||||||
self.image_data: Dict[str, ImageInfo] = {}
|
self.image_data: Dict[str, ImageInfo] = {}
|
||||||
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
||||||
|
|
||||||
@@ -1499,7 +1514,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
nh = int(height * scale + 0.5)
|
nh = int(height * scale + 0.5)
|
||||||
nw = int(width * scale + 0.5)
|
nw = int(width * scale + 0.5)
|
||||||
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
||||||
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
|
interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation)
|
||||||
|
logger.info(f"Interpolation: {interpolation}")
|
||||||
|
image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA)
|
||||||
face_cx = int(face_cx * scale + 0.5)
|
face_cx = int(face_cx * scale + 0.5)
|
||||||
face_cy = int(face_cy * scale + 0.5)
|
face_cy = int(face_cy * scale + 0.5)
|
||||||
height, width = nh, nw
|
height, width = nh, nw
|
||||||
@@ -1596,7 +1613,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
if self.enable_bucket:
|
if self.enable_bucket:
|
||||||
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
|
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if face_cx > 0: # 顔位置情報あり
|
if face_cx > 0: # 顔位置情報あり
|
||||||
@@ -1857,8 +1874,9 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
debug_dataset: bool,
|
debug_dataset: bool,
|
||||||
validation_split: float,
|
validation_split: float,
|
||||||
validation_seed: Optional[int],
|
validation_seed: Optional[int],
|
||||||
|
resize_interpolation: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||||
|
|
||||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||||
|
|
||||||
@@ -2087,6 +2105,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
|
|
||||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||||
|
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||||
if size is not None:
|
if size is not None:
|
||||||
info.image_size = size
|
info.image_size = size
|
||||||
if subset.is_reg:
|
if subset.is_reg:
|
||||||
@@ -2370,8 +2389,9 @@ class ControlNetDataset(BaseDataset):
|
|||||||
debug_dataset: bool,
|
debug_dataset: bool,
|
||||||
validation_split: float,
|
validation_split: float,
|
||||||
validation_seed: Optional[int],
|
validation_seed: Optional[int],
|
||||||
|
resize_interpolation: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||||
|
|
||||||
db_subsets = []
|
db_subsets = []
|
||||||
for subset in subsets:
|
for subset in subsets:
|
||||||
@@ -2403,6 +2423,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
subset.caption_suffix,
|
subset.caption_suffix,
|
||||||
subset.token_warmup_min,
|
subset.token_warmup_min,
|
||||||
subset.token_warmup_step,
|
subset.token_warmup_step,
|
||||||
|
resize_interpolation=subset.resize_interpolation,
|
||||||
)
|
)
|
||||||
db_subsets.append(db_subset)
|
db_subsets.append(db_subset)
|
||||||
|
|
||||||
@@ -2421,6 +2442,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
debug_dataset,
|
debug_dataset,
|
||||||
validation_split,
|
validation_split,
|
||||||
validation_seed,
|
validation_seed,
|
||||||
|
resize_interpolation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||||
@@ -2430,6 +2452,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||||
self.validation_split = validation_split
|
self.validation_split = validation_split
|
||||||
self.validation_seed = validation_seed
|
self.validation_seed = validation_seed
|
||||||
|
self.resize_interpolation = resize_interpolation
|
||||||
|
|
||||||
# assert all conditioning data exists
|
# assert all conditioning data exists
|
||||||
missing_imgs = []
|
missing_imgs = []
|
||||||
@@ -2517,8 +2540,10 @@ class ControlNetDataset(BaseDataset):
|
|||||||
assert (
|
assert (
|
||||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||||
|
|
||||||
|
interpolation = get_cv2_interpolation(self.resize_interpolation)
|
||||||
cond_img = cv2.resize(
|
cond_img = cv2.resize(
|
||||||
cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA
|
cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA
|
||||||
) # INTER_AREAでやりたいのでcv2でリサイズ
|
) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||||
|
|
||||||
# TODO support random crop
|
# TODO support random crop
|
||||||
@@ -2930,7 +2955,7 @@ def load_image(image_path, alpha=False):
|
|||||||
|
|
||||||
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
|
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
|
||||||
def trim_and_resize_if_required(
|
def trim_and_resize_if_required(
|
||||||
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
|
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None
|
||||||
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
|
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
|
||||||
image_height, image_width = image.shape[0:2]
|
image_height, image_width = image.shape[0:2]
|
||||||
original_size = (image_width, image_height) # size before resize
|
original_size = (image_width, image_height) # size before resize
|
||||||
@@ -2938,7 +2963,8 @@ def trim_and_resize_if_required(
|
|||||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||||
# リサイズする
|
# リサイズする
|
||||||
if image_width > resized_size[0] and image_height > resized_size[1]:
|
if image_width > resized_size[0] and image_height > resized_size[1]:
|
||||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
interpolation = get_cv2_interpolation(resize_interpolation)
|
||||||
|
image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||||
else:
|
else:
|
||||||
image = pil_resize(image, resized_size)
|
image = pil_resize(image, resized_size)
|
||||||
|
|
||||||
@@ -2985,7 +3011,7 @@ def load_images_and_masks_for_caching(
|
|||||||
for info in image_infos:
|
for info in image_infos:
|
||||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||||
|
|
||||||
original_sizes.append(original_size)
|
original_sizes.append(original_size)
|
||||||
crop_ltrbs.append(crop_ltrb)
|
crop_ltrbs.append(crop_ltrb)
|
||||||
@@ -3026,7 +3052,7 @@ def cache_batch_latents(
|
|||||||
for info in image_infos:
|
for info in image_infos:
|
||||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||||
|
|
||||||
info.latents_original_size = original_size
|
info.latents_original_size = original_size
|
||||||
info.latents_crop_ltrb = crop_ltrb
|
info.latents_crop_ltrb = crop_ltrb
|
||||||
@@ -6533,3 +6559,29 @@ class LossRecorder:
|
|||||||
if losses == 0:
|
if losses == 0:
|
||||||
return 0
|
return 0
|
||||||
return self.loss_total / losses
|
return self.loss_total / losses
|
||||||
|
|
||||||
|
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Convert interpolation ovalue to cv2 interpolation integer
|
||||||
|
"""
|
||||||
|
if interpolation is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if interpolation == "lanczos":
|
||||||
|
return cv2.INTER_LANCZOS4
|
||||||
|
elif interpolation == "nearest":
|
||||||
|
return cv2.INTER_NEAREST
|
||||||
|
elif interpolation == "bilinear" or interpolation == "linear":
|
||||||
|
return cv2.INTER_LINEAR
|
||||||
|
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||||
|
return cv2.INTER_CUBIC
|
||||||
|
elif interpolation == "area":
|
||||||
|
return cv2.INTER_AREA
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a interpolation function is supported
|
||||||
|
"""
|
||||||
|
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"]
|
||||||
|
|||||||
Reference in New Issue
Block a user