diff --git a/library/config_util.py b/library/config_util.py index a2e07dc6..53727f25 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -75,6 +75,7 @@ class BaseSubsetParams: custom_attributes: Optional[Dict[str, Any]] = None validation_seed: int = 0 validation_split: float = 0.0 + resize_interpolation: Optional[str] = None @dataclass @@ -106,7 +107,7 @@ class BaseDatasetParams: debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 - + resize_interpolation: Optional[str] = None @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -196,6 +197,7 @@ class ConfigSanitizer: "caption_prefix": str, "caption_suffix": str, "custom_attributes": dict, + "resize_interpolation": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -241,6 +243,7 @@ class ConfigSanitizer: "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + "resize_interpolation": str, } # options handled by argparse but not handled by user config @@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} + resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} """) @@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, alpha_mask: {subset.alpha_mask} + resize_interpolation: {subset.resize_interpolation} custom_attributes: {subset.custom_attributes} """), " ") diff --git a/library/train_util.py b/library/train_util.py index 39b4af85..a07834ad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -210,6 +210,7 @@ class ImageInfo: self.text_encoder_pool2: Optional[torch.Tensor] = None self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.resize_interpolation: Optional[str] = None class BucketManager: @@ -434,6 +435,7 @@ class BaseSubset: custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -464,6 +466,8 @@ class BaseSubset: self.validation_seed = validation_seed self.validation_split = validation_split + self.resize_interpolation = resize_interpolation + class DreamBoothSubset(BaseSubset): def __init__( @@ -495,6 +499,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -522,6 +527,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.is_reg = is_reg @@ -564,6 +570,7 @@ class FineTuningSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -591,6 +598,7 @@ class FineTuningSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.metadata_file = metadata_file @@ -629,6 +637,7 @@ class ControlNetSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -656,6 +665,7 @@ class ControlNetSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.conditioning_data_dir = conditioning_data_dir @@ -676,6 +686,7 @@ class BaseDataset(torch.utils.data.Dataset): resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, + resize_interpolation: Optional[str] = None ) -> None: super().__init__() @@ -710,6 +721,10 @@ class BaseDataset(torch.utils.data.Dataset): self.image_transforms = IMAGE_TRANSFORMS + if resize_interpolation is not None: + assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation" + self.resize_interpolation = resize_interpolation + self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -1499,7 +1514,9 @@ class BaseDataset(torch.utils.data.Dataset): nh = int(height * scale + 0.5) nw = int(width * scale + 0.5) assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation) + logger.info(f"Interpolation: {interpolation}") + image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) face_cx = int(face_cx * scale + 0.5) face_cy = int(face_cy * scale + 0.5) height, width = nh, nw @@ -1596,7 +1613,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.enable_bucket: img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size + subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation ) else: if face_cx > 0: # 顔位置情報あり @@ -1857,8 +1874,9 @@ class DreamBoothDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], + resize_interpolation: Optional[str], ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -2087,6 +2105,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation if size is not None: info.image_size = size if subset.is_reg: @@ -2370,8 +2389,9 @@ class ControlNetDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], + resize_interpolation: Optional[str] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) db_subsets = [] for subset in subsets: @@ -2403,6 +2423,7 @@ class ControlNetDataset(BaseDataset): subset.caption_suffix, subset.token_warmup_min, subset.token_warmup_step, + resize_interpolation=subset.resize_interpolation, ) db_subsets.append(db_subset) @@ -2421,6 +2442,7 @@ class ControlNetDataset(BaseDataset): debug_dataset, validation_split, validation_seed, + resize_interpolation, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2430,6 +2452,7 @@ class ControlNetDataset(BaseDataset): self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split self.validation_seed = validation_seed + self.resize_interpolation = resize_interpolation # assert all conditioning data exists missing_imgs = [] @@ -2517,8 +2540,10 @@ class ControlNetDataset(BaseDataset): assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" + + interpolation = get_cv2_interpolation(self.resize_interpolation) cond_img = cv2.resize( - cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA + cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA ) # INTER_AREAでやりたいのでcv2でリサイズ # TODO support random crop @@ -2930,7 +2955,7 @@ def load_image(image_path, alpha=False): # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize @@ -2938,7 +2963,8 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする if image_width > resized_size[0] and image_height > resized_size[1]: - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + interpolation = get_cv2_interpolation(resize_interpolation) + image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ else: image = pil_resize(image, resized_size) @@ -2985,7 +3011,7 @@ def load_images_and_masks_for_caching( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) original_sizes.append(original_size) crop_ltrbs.append(crop_ltrb) @@ -3026,7 +3052,7 @@ def cache_batch_latents( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -6533,3 +6559,29 @@ class LossRecorder: if losses == 0: return 0 return self.loss_total / losses + +def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: + """ + Convert interpolation ovalue to cv2 interpolation integer + """ + if interpolation is None: + return None + + if interpolation == "lanczos": + return cv2.INTER_LANCZOS4 + elif interpolation == "nearest": + return cv2.INTER_NEAREST + elif interpolation == "bilinear" or interpolation == "linear": + return cv2.INTER_LINEAR + elif interpolation == "bicubic" or interpolation == "cubic": + return cv2.INTER_CUBIC + elif interpolation == "area": + return cv2.INTER_AREA + else: + return None + +def validate_interpolation_fn(interpolation_str: str) -> bool: + """ + Check if a interpolation function is supported + """ + return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"]