diff --git a/gen_img.py b/gen_img.py
index a24220a0..4395b790 100644
--- a/gen_img.py
+++ b/gen_img.py
@@ -1,5 +1,6 @@
 import itertools
 import json
+from types import SimpleNamespace
 from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
 import glob
 import importlib
@@ -20,7 +21,7 @@
 import diffusers
 import numpy as np
 import torch
-from library.ipex_interop import init_ipex
+from library.device_utils import init_ipex, clean_memory, get_preferred_device
 
 init_ipex()
 
@@ -338,7 +339,7 @@ class PipelineLike:
         self.clip_vision_model: CLIPVisionModelWithProjection = None
         self.clip_vision_processor: CLIPImageProcessor = None
         self.clip_vision_strength = 0.0
-        
+
         # Textual Inversion
         self.token_replacements_list = []
         for _ in range(len(self.text_encoders)):
@@ -1926,6 +1927,18 @@
         )
         pipe.set_gradual_latent(gradual_latent)
 
+    # Flexible Zero Slicing
+    if args.flexible_zero_slicing_mask:
+        # mask 画像は背景 255、zero にする部分 0 とする
+        print(f"loading Flexible Zero Slicing mask")
+        fz_mask = Image.open(args.flexible_zero_slicing_mask).convert("RGB")
+        fz_mask = np.array(fz_mask).astype(np.float32) / 255.0
+        fz_mask = fz_mask[:, :, 0]
+        fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
+
+        # only for sdxl
+        unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
+
     # Textual Inversionを処理する
     if args.textual_inversion_embeddings:
         token_ids_embeds1 = []
@@ -2146,10 +2159,17 @@
 
                 if args.network_regional_mask_max_color_codes:
                     # カラーコードでマスクを指定する
-                    ch0 = (i + 1) & 1
-                    ch1 = ((i + 1) >> 1) & 1
-                    ch2 = ((i + 1) >> 2) & 1
-                    np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
+                    # 0-7: RGB 3bitで8色, 0/255
+                    # 8-15: RGB 3bitで8色, 0/127
+                    code = (i % 7) + 1
+                    r = code & 1
+                    g = (code & 2) >> 1
+                    b = (code & 4) >> 2
+                    if i < 7:
+                        color = (r * 255, g * 255, b * 255)
+                    else:
+                        color = (r * 127, g * 127, b * 127)
+                    np_mask = np.all(np_mask == color, axis=2)
                     np_mask = np_mask.astype(np.uint8) * 255
                 else:
                     np_mask = np_mask[:, :, i]
@@ -3312,6 +3332,24 @@
         + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
     )
+    parser.add_argument(
+        "--flexible_zero_slicing_mask",
+        type=str,
+        default=None,
+        help="mask for flexible zero slicing / flexible zero slicingのマスク",
+    )
+    parser.add_argument(
+        "--flexible_zero_slicing_depth",
+        type=int,
+        default=None,
+        help="depth for flexible zero slicing / flexible zero slicingのdepth",
+    )
+    parser.add_argument(
+        "--flexible_zero_slicing_timesteps",
+        type=int,
+        default=None,
+        help="timesteps for flexible zero slicing / flexible zero slicingのtimesteps",
+    )
 
     # # parser.add_argument(
     #     "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
     # )
diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py
index 673cf9f6..2aaa65c3 100644
--- a/library/sdxl_original_unet.py
+++ b/library/sdxl_original_unet.py
@@ -24,15 +24,17 @@
 import math
 from types import SimpleNamespace
-from typing import Any, Optional
+from typing import Any, List, Optional
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
 from einops import rearrange
 from .utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
 
 IN_CHANNELS: int = 4
@@ -1114,6 +1116,46 @@ class SdxlUNet2DConditionModel(nn.Module):
         return h
 
 
+def get_mask_from_mask_dic(mask_dic, shape):
+    if mask_dic is None or len(mask_dic) == 0:
+        return None
+    mask = mask_dic.get(shape, None)
+    if mask is None:
+        # resize from the original mask
+        mask = mask_dic.get((0, 0), None)
+        org_dtype = mask.dtype
+        if org_dtype == torch.bfloat16:
+            mask = mask.to(torch.float32)
+        mask = F.interpolate(mask, size=shape, mode="area")  # area is needed for keeping the mask value less than 1
+        mask = (mask == 1).to(dtype=org_dtype, device=mask.device)
+        mask_dic[shape] = mask
+        # for m in mask[0,0]:
+        #     print("".join([f"{int(v)}" for v in m]))
+
+    return mask
+
+
+# class Conv2dZeroSlicing(nn.Conv2d):
+#     def __init__(self, *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+#         self.mask_dic = None
+#         self.enable_flag = None
+
+#     def set_reference_for_enable_and_mask_dic(self, enable_flag, mask_dic):
+#         self.enable_flag = enable_flag
+#         self.mask_dic = mask_dic
+
+#     def forward(self, input: torch.Tensor) -> torch.Tensor:
+#         print(self.enable_flag, self.mask_dic, input.shape[-2:])
+#         if self.enable_flag is None or not self.enable_flag[0] or self.mask_dic is None or len(self.mask_dic) == 0:
+#             return super().forward(input)
+
+#         mask = get_mask_from_mask_dic(self.mask_dic, input.shape[-2:])
+#         if mask is not None:
+#             input = input * mask
+#         return super().forward(input)
+
+
+
 class InferSdxlUNet2DConditionModel:
     def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
         self.delegate = original_unet
@@ -1129,10 +1171,70 @@
         self.ds_timesteps_2 = None
         self.ds_ratio = None
 
+        # flexible zero slicing
+        self.fz_depth = None
+        self.fz_enable_flag = [False]
+        self.fz_mask_dic = {}
+        for name, module in self.delegate.named_modules():
+            if isinstance(module, nn.Conv2d):
+                if module.kernel_size == (3, 3):
+                    module.enable_flag = self.fz_enable_flag
+                    module.mask_dic = self.fz_mask_dic
+
+                    # replace forward method
+                    module.original_forward = module.forward
+
+                    def make_forward(module):
+                        def forward_conv2d_zero_slicing(input: torch.Tensor) -> torch.Tensor:
+                            if not module.enable_flag[0] or len(module.mask_dic) == 0:
+                                return module.original_forward(input)
+
+                            mask = get_mask_from_mask_dic(module.mask_dic, input.shape[-2:])
+                            input = input * mask
+                            return module.original_forward(input)
+
+                        return forward_conv2d_zero_slicing
+
+                    module.forward = make_forward(module)
+
+    # def forward_conv2d_zero_slicing(self, input: torch.Tensor) -> torch.Tensor:
+    #     print(self.__class__.__name__, "forward_conv2d_zero_slicing")
+    #     print(self.enable_flag, self.mask_dic, input.shape[-2:])
+    #     if self.fz_depth is None or not self.fz_enable_flag[0] or self.fz_mask_dic is None or len(self.fz_mask_dic) == 0:
+    #         return self.original_forward(input)
+
+    #     mask = get_mask_from_mask_dic(self.fz_mask_dic, input.shape[-2:])
+    #     if mask is not None:
+    #         input = input * mask
+    #     return self.original_forward(input)
+
+    # for name, module in list(self.delegate.named_modules()):
+    #     if isinstance(module, nn.Conv2d):
+    #         if module.kernel_size == (3, 3):
+    #             # replace Conv2d with Conv2dZeroSlicing
+    #             new_conv2d = Conv2dZeroSlicing(
+    #                 module.in_channels,
+    #                 module.out_channels,
+    #                 module.kernel_size,
+    #                 module.stride,
+    #                 module.padding,
+    #                 module.dilation,
+    #                 module.groups,
+    #                 module.bias is not None,
+    #                 module.padding_mode,
+    #             )
+    #             new_conv2d.set_reference_for_enable_and_mask_dic(self.fz_enable_flag, self.fz_mask_dic)
+    #             print(f"replace {name} with Conv2dZeroSlicing")
+    #             setattr(self.delegate, name, new_conv2d)
+
+    #             # copy parameters
+    #             new_conv2d.weight = module.weight
+    #             new_conv2d.bias = module.bias
+
     # call original model's methods
     def __getattr__(self, name):
         return getattr(self.delegate, name)
-    
+
     def __call__(self, *args, **kwargs):
         return self.delegate(*args, **kwargs)
 
@@ -1154,6 +1256,22 @@
         self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
         self.ds_ratio = ds_ratio
 
+    def set_flexible_zero_slicing(self, mask: torch.Tensor, depth: int, timesteps: int = None):
+        # mask is arbitrary shape, 0 for zero slicing.
+        if depth is None or depth < 0:
+            logger.info("Flexible zero slicing is disabled.")
+            self.fz_depth = None
+            self.fz_mask = None
+            self.fz_timesteps = None
+            self.fz_mask_dic.clear()
+        else:
+            logger.info(f"Flexible zero slicing is enabled: [depth={depth}]")
+            self.fz_depth = depth
+            self.fz_mask = mask
+            self.fz_timesteps = timesteps if timesteps is not None else 0  # 0 = no timestep limit; forward() compares `timesteps[0] > fz_timesteps`, which raises TypeError on None
+            self.fz_mask_dic.clear()
+            self.fz_mask_dic[(0, 0)] = mask.unsqueeze(0).unsqueeze(0)
+
     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         r"""
         current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
@@ -1188,7 +1306,14 @@
         # h = x.type(self.dtype)
         h = x
 
+        self.fz_enable_flag[0] = False
+
         for depth, module in enumerate(_self.input_blocks):
+            # Flexible Zero Slicing
+            if self.fz_depth is not None:
+                self.fz_enable_flag[0] = depth >= self.fz_depth and timesteps[0] > self.fz_timesteps
+            # print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
+
             # Deep Shrink
             if self.ds_depth_1 is not None:
                 if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
@@ -1208,7 +1333,12 @@
 
         h = call_module(_self.middle_block, h, emb, context)
 
-        for module in _self.output_blocks:
+        for depth, module in enumerate(_self.output_blocks):
+            # Flexible Zero Slicing
+            if self.fz_depth is not None and len(self.output_blocks) - depth <= self.fz_depth:
+                self.fz_enable_flag[0] = False
+            # print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
+
             # Deep Shrink
             if self.ds_depth_1 is not None:
                 if hs[-1].shape[-2:] != h.shape[-2:]: