From dbe78a8638bcd55cf781442ef2a1cee2144fa521 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 25 Feb 2024 08:53:43 +0900 Subject: [PATCH] scale crafter --- gen_img.py | 20 +++++++++++- library/sdxl_original_unet.py | 58 ++++++++++++++++++++++++++++++----- 2 files changed, 70 insertions(+), 8 deletions(-) diff --git a/gen_img.py b/gen_img.py index a00f8124..293f8d21 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2138,7 +2138,7 @@ def main(args): # 輪郭を新しい配列に描画 cv2.drawContours(fz_mask, contours, -1, (0, 0, 0), 1) - + fz_mask = fz_mask.astype(np.float32) / 255.0 fz_mask = fz_mask[:, :, 0] fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device) @@ -2146,6 +2146,10 @@ def main(args): # only for sdxl unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps) + # Dilated Conv Hires fix + if args.dilated_conv_hires_fix_depth is not None: + unet.set_dilated_conv(args.dilated_conv_hires_fix_depth, args.dilated_conv_hires_fix_timesteps) + # 画像サイズにオプション指定があるときはリサイズする if args.W is not None and args.H is not None: # highres fix を考慮に入れる @@ -3365,6 +3369,20 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="timesteps for flexible zero slicing / flexible zero slicingのtimesteps", ) + + parser.add_argument( + "--dilated_conv_hires_fix_depth", + type=int, + default=None, + help="depth for dilated conv hires fix / dilated conv hires fixのdepth", + ) + parser.add_argument( + "--dilated_conv_hires_fix_timesteps", + type=int, + default=None, + help="timesteps for dilated conv hires fix / dilated conv hires fixのtimesteps", + ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 2aaa65c3..f92d9442 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -1171,6 +1171,32 @@ class InferSdxlUNet2DConditionModel: self.ds_timesteps_2 = None self.ds_ratio = None + # Dilated Conv + self.dc_depth = None + self.dc_timesteps = None + self.dc_enable_flag = [False] + for name, module in self.delegate.named_modules(): + if isinstance(module, nn.Conv2d): + if module.kernel_size == (3, 3) and module.dilation == (1, 1): + module.dc_enable_flag = self.dc_enable_flag + + # replace forward method + module.dc_original_forward = module.forward + + def make_forward_dilated_conv(module): + def forward_conv2d_dilated_conv(input: torch.Tensor) -> torch.Tensor: + if module.dc_enable_flag[0]: + module.dilation = (1, 2) + module.padding = (1, 2) + else: + module.dilation = (1, 1) + module.padding = (1, 1) + return module.dc_original_forward(input) + + return forward_conv2d_dilated_conv + + module.forward = make_forward_dilated_conv(module) + # flexible zero slicing self.fz_depth = None self.fz_enable_flag = [False] @@ -1178,20 +1204,20 @@ class InferSdxlUNet2DConditionModel: for name, module in self.delegate.named_modules(): if isinstance(module, nn.Conv2d): if module.kernel_size == (3, 3): - module.enable_flag = self.fz_enable_flag - module.mask_dic = self.fz_mask_dic + module.fz_enable_flag = self.fz_enable_flag + module.fz_mask_dic = self.fz_mask_dic # replace forward method - module.original_forward = module.forward + module.fz_original_forward = module.forward def make_forward(module): def forward_conv2d_zero_slicing(input: torch.Tensor) -> torch.Tensor: - if not module.enable_flag[0] or len(module.mask_dic) == 0: - return module.original_forward(input) + if not module.fz_enable_flag[0] or len(module.fz_mask_dic) == 0: + return module.fz_original_forward(input) - mask = get_mask_from_mask_dic(module.mask_dic, input.shape[-2:]) + mask = get_mask_from_mask_dic(module.fz_mask_dic, input.shape[-2:]) input = input * mask - return module.original_forward(input) + return module.fz_original_forward(input) return forward_conv2d_zero_slicing @@ -1272,6 +1298,16 @@ class InferSdxlUNet2DConditionModel: self.fz_mask_dic.clear() self.fz_mask_dic[(0, 0)] = mask.unsqueeze(0).unsqueeze(0) + def set_dilated_conv(self, depth: int, timesteps: int = None): + if depth is None or depth < 0: + logger.info("Dilated Conv is disabled.") + self.dc_depth = None + self.dc_timesteps = None + else: + logger.info(f"Dilated Conv is enabled: [depth={depth}]") + self.dc_depth = depth + self.dc_timesteps = timesteps + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): r""" current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. @@ -1309,6 +1345,10 @@ class InferSdxlUNet2DConditionModel: self.fz_enable_flag[0] = False for depth, module in enumerate(_self.input_blocks): + # Dilated Conv + if self.dc_depth is not None: + self.dc_enable_flag[0] = depth >= self.dc_depth and timesteps[0] > self.dc_timesteps + # Flexible Zero Slicing if self.fz_depth is not None: self.fz_enable_flag[0] = depth >= self.fz_depth and timesteps[0] > self.fz_timesteps @@ -1334,6 +1374,10 @@ class InferSdxlUNet2DConditionModel: h = call_module(_self.middle_block, h, emb, context) for depth, module in enumerate(_self.output_blocks): + # Dilated Conv + if self.dc_depth is not None and len(_self.output_blocks) - depth <= self.dc_depth: + self.dc_enable_flag[0] = False + # Flexible Zero Slicing if self.fz_depth is not None and len(self.output_blocks) - depth <= self.fz_depth: self.fz_enable_flag[0] = False