Compare commits

...

4 Commits

Author SHA1 Message Date
Kohya S
dbe78a8638 scale crafter 2024-02-25 08:53:43 +09:00
Kohya S
bae116a031 Merge branch 'dev' into flexible-zero-slicing 2024-02-24 22:42:41 +09:00
Kohya S
6aa2d99219 make mask for flexible zero slicing from attncouple mask 2024-02-24 21:38:57 +09:00
Kohya S
725bab124b impl flexible zero slicing 2024-02-24 21:00:38 +09:00
2 changed files with 252 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
import itertools import itertools
import json import json
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob import glob
import importlib import importlib
@@ -2118,6 +2119,37 @@ def main(args):
l.extend([im] * args.images_per_prompt) l.extend([im] * args.images_per_prompt)
mask_images = l mask_images = l
# Flexible Zero Slicing
if args.flexible_zero_slicing_depth is not None:
# CV2 が必要
import cv2
# mask 画像は背景 255、zero にする部分 0 とする
np_mask = np.array(mask_images[0].convert("RGB"))
fz_mask = np.full(np_mask.shape, 255, dtype=np.uint8)
# 各チャンネルに対して処理
for i in range(3):
# チャンネルを抽出
channel = np_mask[:, :, i]
# 輪郭を検出
contours, _ = cv2.findContours(channel, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 輪郭を新しい配列に描画
cv2.drawContours(fz_mask, contours, -1, (0, 0, 0), 1)
fz_mask = fz_mask.astype(np.float32) / 255.0
fz_mask = fz_mask[:, :, 0]
fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
# only for sdxl
unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
# Dilated Conv Hires fix
if args.dilated_conv_hires_fix_depth is not None:
unet.set_dilated_conv(args.dilated_conv_hires_fix_depth, args.dilated_conv_hires_fix_timesteps)
# 画像サイズにオプション指定があるときはリサイズする # 画像サイズにオプション指定があるときはリサイズする
if args.W is not None and args.H is not None: if args.W is not None and args.H is not None:
# highres fix を考慮に入れる # highres fix を考慮に入れる
@@ -2146,10 +2178,17 @@ def main(args):
if args.network_regional_mask_max_color_codes: if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する # カラーコードでマスクを指定する
ch0 = (i + 1) & 1 # 0-7: RGB 3bitで8色, 0/255
ch1 = ((i + 1) >> 1) & 1 # 8-15: RGB 3bitで8色, 0/127
ch2 = ((i + 1) >> 2) & 1 code = (i % 7) + 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) r = code & 1
g = (code & 2) >> 1
b = (code & 4) >> 2
if i < 7:
color = (r * 255, g * 255, b * 255)
else:
color = (r * 127, g * 127, b * 127)
np_mask = np.all(np_mask == color, axis=2)
np_mask = np_mask.astype(np.uint8) * 255 np_mask = np_mask.astype(np.uint8) * 255
else: else:
np_mask = np_mask[:, :, i] np_mask = np_mask[:, :, i]
@@ -3312,6 +3351,38 @@ def setup_parser() -> argparse.ArgumentParser:
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
) )
# parser.add_argument(
# "--flexible_zero_slicing_mask",
# type=str,
# default=None,
# help="mask for flexible zero slicing / flexible zero slicingのマスク",
# )
parser.add_argument(
"--flexible_zero_slicing_depth",
type=int,
default=None,
help="depth for flexible zero slicing / flexible zero slicingのdepth",
)
parser.add_argument(
"--flexible_zero_slicing_timesteps",
type=int,
default=None,
help="timesteps for flexible zero slicing / flexible zero slicingのtimesteps",
)
parser.add_argument(
"--dilated_conv_hires_fix_depth",
type=int,
default=None,
help="depth for dilated conv hires fix / dilated conv hires fixのdepth",
)
parser.add_argument(
"--dilated_conv_hires_fix_timesteps",
type=int,
default=None,
help="timesteps for dilated conv hires fix / dilated conv hires fixのtimesteps",
)
# # parser.add_argument( # # parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# ) # )

View File

@@ -24,15 +24,17 @@
import math import math
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, Optional from typing import Any, List, Optional
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from einops import rearrange from einops import rearrange
from .utils import setup_logging from .utils import setup_logging
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4 IN_CHANNELS: int = 4
@@ -1114,6 +1116,46 @@ class SdxlUNet2DConditionModel(nn.Module):
return h return h
def get_mask_from_mask_dic(mask_dic, shape):
    """Return the 0/1 slicing mask resized to `shape`, caching results in `mask_dic`.

    `mask_dic` maps an (H, W) tuple to a 4D mask tensor of shape (1, 1, H, W);
    the sentinel key (0, 0) holds the original full-resolution mask.  Returns
    None when no mask is available (empty dict, None, or no (0, 0) entry).
    """
    if not mask_dic:
        return None
    mask = mask_dic.get(shape, None)
    if mask is None:
        # resize from the original full-resolution mask
        base = mask_dic.get((0, 0), None)
        if base is None:
            # no original mask registered: nothing to resize from
            return None
        org_dtype = base.dtype
        # F.interpolate does not support bfloat16; compute in float32 and restore
        if org_dtype == torch.bfloat16:
            base = base.to(torch.float32)
        # "area" (average pooling) keeps boundary cells below 1, so the `== 1`
        # test zeroes every cell that overlaps the masked (0-valued) region
        mask = F.interpolate(base, size=shape, mode="area")
        mask = (mask == 1).to(dtype=org_dtype, device=mask.device)
        mask_dic[shape] = mask
    # for m in mask[0,0]:
    #     print("".join([f"{int(v)}" for v in m]))
    return mask
# class Conv2dZeroSlicing(nn.Conv2d):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.mask_dic = None
# self.enable_flag = None
# def set_reference_for_enable_and_mask_dic(self, enable_flag, mask_dic):
# self.enable_flag = enable_flag
# self.mask_dic = mask_dic
# def forward(self, input: torch.Tensor) -> torch.Tensor:
# print(self.enable_flag, self.mask_dic, input.shape[-2:])
# if self.enable_flag is None or not self.enable_flag[0] or self.mask_dic is None or len(self.mask_dic) == 0:
# return super().forward(input)
# mask = get_mask_from_mask_dic(self.mask_dic, input.shape[-2:])
# if mask is not None:
# input = input * mask
# return super().forward(input)
class InferSdxlUNet2DConditionModel: class InferSdxlUNet2DConditionModel:
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs): def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
self.delegate = original_unet self.delegate = original_unet
@@ -1129,10 +1171,96 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = None self.ds_timesteps_2 = None
self.ds_ratio = None self.ds_ratio = None
# Dilated Conv
self.dc_depth = None
self.dc_timesteps = None
self.dc_enable_flag = [False]
for name, module in self.delegate.named_modules():
if isinstance(module, nn.Conv2d):
if module.kernel_size == (3, 3) and module.dilation == (1, 1):
module.dc_enable_flag = self.dc_enable_flag
# replace forward method
module.dc_original_forward = module.forward
def make_forward_dilated_conv(module):
def forward_conv2d_dilated_conv(input: torch.Tensor) -> torch.Tensor:
if module.dc_enable_flag[0]:
module.dilation = (1, 2)
module.padding = (1, 2)
else:
module.dilation = (1, 1)
module.padding = (1, 1)
return module.dc_original_forward(input)
return forward_conv2d_dilated_conv
module.forward = make_forward_dilated_conv(module)
# flexible zero slicing
self.fz_depth = None
self.fz_enable_flag = [False]
self.fz_mask_dic = {}
for name, module in self.delegate.named_modules():
if isinstance(module, nn.Conv2d):
if module.kernel_size == (3, 3):
module.fz_enable_flag = self.fz_enable_flag
module.fz_mask_dic = self.fz_mask_dic
# replace forward method
module.fz_original_forward = module.forward
def make_forward(module):
def forward_conv2d_zero_slicing(input: torch.Tensor) -> torch.Tensor:
if not module.fz_enable_flag[0] or len(module.fz_mask_dic) == 0:
return module.fz_original_forward(input)
mask = get_mask_from_mask_dic(module.fz_mask_dic, input.shape[-2:])
input = input * mask
return module.fz_original_forward(input)
return forward_conv2d_zero_slicing
module.forward = make_forward(module)
# def forward_conv2d_zero_slicing(self, input: torch.Tensor) -> torch.Tensor:
# print(self.__class__.__name__, "forward_conv2d_zero_slicing")
# print(self.enable_flag, self.mask_dic, input.shape[-2:])
# if self.fz_depth is None or not self.fz_enable_flag[0] or self.fz_mask_dic is None or len(self.fz_mask_dic) == 0:
# return self.original_forward(input)
# mask = get_mask_from_mask_dic(self.fz_mask_dic, input.shape[-2:])
# if mask is not None:
# input = input * mask
# return self.original_forward(input)
# for name, module in list(self.delegate.named_modules()):
# if isinstance(module, nn.Conv2d):
# if module.kernel_size == (3, 3):
# # replace Conv2d with Conv2dZeroSlicing
# new_conv2d = Conv2dZeroSlicing(
# module.in_channels,
# module.out_channels,
# module.kernel_size,
# module.stride,
# module.padding,
# module.dilation,
# module.groups,
# module.bias is not None,
# module.padding_mode,
# )
# new_conv2d.set_reference_for_enable_and_mask_dic(self.fz_enable_flag, self.fz_mask_dic)
# print(f"replace {name} with Conv2dZeroSlicing")
# setattr(self.delegate, name, new_conv2d)
# # copy parameters
# new_conv2d.weight = module.weight
# new_conv2d.bias = module.bias
# call original model's methods # call original model's methods
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.delegate, name) return getattr(self.delegate, name)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs) return self.delegate(*args, **kwargs)
@@ -1154,6 +1282,32 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio self.ds_ratio = ds_ratio
def set_flexible_zero_slicing(self, mask: torch.Tensor, depth: int, timesteps: Optional[int] = None):
    """Enable or disable flexible zero slicing.

    Args:
        mask: 2D tensor (H, W); 0.0 where activations are sliced to zero, 1.0 elsewhere.
        depth: input-block depth from which slicing applies; None or negative disables.
        timesteps: slicing is active while the current timestep is above this
            threshold. Defaults to 0 (active at every timestep) when not given.
    """
    # mask is arbitrary shape, 0 for zero slicing.
    if depth is None or depth < 0:
        logger.info("Flexible zero slicing is disabled.")
        self.fz_depth = None
        self.fz_mask = None
        self.fz_timesteps = None
        self.fz_mask_dic.clear()
    else:
        logger.info(f"Flexible zero slicing is enabled: [depth={depth}]")
        self.fz_depth = depth
        self.fz_mask = mask
        # forward() evaluates `timesteps[0] > self.fz_timesteps`; default to 0 so
        # an unset threshold means "always active" instead of a TypeError, matching
        # how set_deep_shrink normalizes its timestep thresholds.
        self.fz_timesteps = timesteps if timesteps is not None else 0
        self.fz_mask_dic.clear()
        # (0, 0) is the sentinel key for the original-resolution mask, stored 4D
        # (1, 1, H, W) so it can be fed to F.interpolate later.
        self.fz_mask_dic[(0, 0)] = mask.unsqueeze(0).unsqueeze(0)
def set_dilated_conv(self, depth: int, timesteps: Optional[int] = None):
    """Enable or disable the dilated-conv hires fix.

    Args:
        depth: input-block depth from which 3x3 convs switch to dilation (1, 2);
            None or negative disables the fix.
        timesteps: the fix is active while the current timestep is above this
            threshold. Defaults to 0 (active at every timestep) when not given.
    """
    if depth is None or depth < 0:
        logger.info("Dilated Conv is disabled.")
        self.dc_depth = None
        self.dc_timesteps = None
    else:
        logger.info(f"Dilated Conv is enabled: [depth={depth}]")
        self.dc_depth = depth
        # forward() evaluates `timesteps[0] > self.dc_timesteps`; default to 0 so
        # an unset threshold means "always active" instead of a TypeError, matching
        # how set_deep_shrink normalizes its timestep thresholds.
        self.dc_timesteps = timesteps if timesteps is not None else 0
def forward(self, x, timesteps=None, context=None, y=None, **kwargs): def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r""" r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
@@ -1188,7 +1342,18 @@ class InferSdxlUNet2DConditionModel:
# h = x.type(self.dtype) # h = x.type(self.dtype)
h = x h = x
self.fz_enable_flag[0] = False
for depth, module in enumerate(_self.input_blocks): for depth, module in enumerate(_self.input_blocks):
# Dilated Conv
if self.dc_depth is not None:
self.dc_enable_flag[0] = depth >= self.dc_depth and timesteps[0] > self.dc_timesteps
# Flexible Zero Slicing
if self.fz_depth is not None:
self.fz_enable_flag[0] = depth >= self.fz_depth and timesteps[0] > self.fz_timesteps
# print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
# Deep Shrink # Deep Shrink
if self.ds_depth_1 is not None: if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
@@ -1208,7 +1373,16 @@ class InferSdxlUNet2DConditionModel:
h = call_module(_self.middle_block, h, emb, context) h = call_module(_self.middle_block, h, emb, context)
for module in _self.output_blocks: for depth, module in enumerate(_self.output_blocks):
# Dilated Conv
if self.dc_depth is not None and len(_self.output_blocks) - depth <= self.dc_depth:
self.dc_enable_flag[0] = False
# Flexible Zero Slicing
if self.fz_depth is not None and len(self.output_blocks) - depth <= self.fz_depth:
self.fz_enable_flag[0] = False
# print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
# Deep Shrink # Deep Shrink
if self.ds_depth_1 is not None: if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]: if hs[-1].shape[-2:] != h.shape[-2:]: