Implement flexible zero slicing

This commit is contained in:
Kohya S
2024-02-24 21:00:38 +09:00
parent d1fb480887
commit 725bab124b
2 changed files with 177 additions and 9 deletions

View File

@@ -1,5 +1,6 @@
import itertools
import json
from types import SimpleNamespace
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
@@ -20,7 +21,7 @@ import diffusers
import numpy as np
import torch
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory, get_preferred_device
init_ipex()
@@ -338,7 +339,7 @@ class PipelineLike:
self.clip_vision_model: CLIPVisionModelWithProjection = None
self.clip_vision_processor: CLIPImageProcessor = None
self.clip_vision_strength = 0.0
# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
@@ -1926,6 +1927,18 @@ def main(args):
)
pipe.set_gradual_latent(gradual_latent)
# Flexible Zero Slicing
if args.flexible_zero_slicing_mask:
# mask 画像は背景 255、zero にする部分 0 とする
print(f"loading Flexible Zero Slicing mask")
fz_mask = Image.open(args.flexible_zero_slicing_mask).convert("RGB")
fz_mask = np.array(fz_mask).astype(np.float32) / 255.0
fz_mask = fz_mask[:, :, 0]
fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
# only for sdxl
unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
# Textual Inversionを処理する
if args.textual_inversion_embeddings:
token_ids_embeds1 = []
@@ -2146,10 +2159,17 @@ def main(args):
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
# 0-7: RGB 3bitで8色, 0/255
# 8-15: RGB 3bitで8色, 0/127
code = (i % 7) + 1
r = code & 1
g = (code & 2) >> 1
b = (code & 4) >> 2
if i < 7:
color = (r * 255, g * 255, b * 255)
else:
color = (r * 127, g * 127, b * 127)
np_mask = np.all(np_mask == color, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
@@ -3312,6 +3332,24 @@ def setup_parser() -> argparse.ArgumentParser:
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
)
parser.add_argument(
"--flexible_zero_slicing_mask",
type=str,
default=None,
help="mask for flexible zero slicing / flexible zero slicingのマスク",
)
parser.add_argument(
"--flexible_zero_slicing_depth",
type=int,
default=None,
help="depth for flexible zero slicing / flexible zero slicingのdepth",
)
parser.add_argument(
"--flexible_zero_slicing_timesteps",
type=int,
default=None,
help="timesteps for flexible zero slicing / flexible zero slicingのtimesteps",
)
# # parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )

View File

@@ -24,15 +24,17 @@
import math
from types import SimpleNamespace
from typing import Any, Optional
from typing import Any, List, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4
@@ -1114,6 +1116,46 @@ class SdxlUNet2DConditionModel(nn.Module):
return h
def get_mask_from_mask_dic(mask_dic, shape):
    """Return the cached mask for `shape`, building it from the original mask on a miss.

    `mask_dic` maps spatial shapes (H, W) to binary mask tensors; the key
    (0, 0) holds the original full-resolution mask. Returns None when no
    masks are registered.
    """
    if not mask_dic:
        return None
    cached = mask_dic.get(shape)
    if cached is not None:
        return cached
    # Cache miss: downscale the original mask stored under the (0, 0) key.
    base = mask_dic.get((0, 0), None)
    src_dtype = base.dtype
    if src_dtype == torch.bfloat16:
        # interpolate in float32, then cast back below
        base = base.to(torch.float32)
    # "area" pooling keeps every partially-covered cell strictly below 1,
    # so the == 1 test keeps only cells fully inside the original mask.
    resized = F.interpolate(base, size=shape, mode="area")
    resized = (resized == 1).to(dtype=src_dtype, device=resized.device)
    mask_dic[shape] = resized
    return resized
# class Conv2dZeroSlicing(nn.Conv2d):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.mask_dic = None
# self.enable_flag = None
# def set_reference_for_enable_and_mask_dic(self, enable_flag, mask_dic):
# self.enable_flag = enable_flag
# self.mask_dic = mask_dic
# def forward(self, input: torch.Tensor) -> torch.Tensor:
# print(self.enable_flag, self.mask_dic, input.shape[-2:])
# if self.enable_flag is None or not self.enable_flag[0] or self.mask_dic is None or len(self.mask_dic) == 0:
# return super().forward(input)
# mask = get_mask_from_mask_dic(self.mask_dic, input.shape[-2:])
# if mask is not None:
# input = input * mask
# return super().forward(input)
class InferSdxlUNet2DConditionModel:
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
self.delegate = original_unet
@@ -1129,10 +1171,70 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = None
self.ds_ratio = None
# flexible zero slicing
self.fz_depth = None
self.fz_enable_flag = [False]
self.fz_mask_dic = {}
for name, module in self.delegate.named_modules():
if isinstance(module, nn.Conv2d):
if module.kernel_size == (3, 3):
module.enable_flag = self.fz_enable_flag
module.mask_dic = self.fz_mask_dic
# replace forward method
module.original_forward = module.forward
def make_forward(module):
def forward_conv2d_zero_slicing(input: torch.Tensor) -> torch.Tensor:
if not module.enable_flag[0] or len(module.mask_dic) == 0:
return module.original_forward(input)
mask = get_mask_from_mask_dic(module.mask_dic, input.shape[-2:])
input = input * mask
return module.original_forward(input)
return forward_conv2d_zero_slicing
module.forward = make_forward(module)
# def forward_conv2d_zero_slicing(self, input: torch.Tensor) -> torch.Tensor:
# print(self.__class__.__name__, "forward_conv2d_zero_slicing")
# print(self.enable_flag, self.mask_dic, input.shape[-2:])
# if self.fz_depth is None or not self.fz_enable_flag[0] or self.fz_mask_dic is None or len(self.fz_mask_dic) == 0:
# return self.original_forward(input)
# mask = get_mask_from_mask_dic(self.fz_mask_dic, input.shape[-2:])
# if mask is not None:
# input = input * mask
# return self.original_forward(input)
# for name, module in list(self.delegate.named_modules()):
# if isinstance(module, nn.Conv2d):
# if module.kernel_size == (3, 3):
# # replace Conv2d with Conv2dZeroSlicing
# new_conv2d = Conv2dZeroSlicing(
# module.in_channels,
# module.out_channels,
# module.kernel_size,
# module.stride,
# module.padding,
# module.dilation,
# module.groups,
# module.bias is not None,
# module.padding_mode,
# )
# new_conv2d.set_reference_for_enable_and_mask_dic(self.fz_enable_flag, self.fz_mask_dic)
# print(f"replace {name} with Conv2dZeroSlicing")
# setattr(self.delegate, name, new_conv2d)
# # copy parameters
# new_conv2d.weight = module.weight
# new_conv2d.bias = module.bias
# call original model's methods
def __getattr__(self, name):
    # Fallback attribute lookup: any attribute not defined on this wrapper is
    # forwarded to the wrapped UNet (self.delegate), so the wrapper is a
    # drop-in replacement for the original model object.
    return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
    # Calling the wrapper delegates directly to the wrapped model's __call__,
    # preserving the original UNet's calling convention for existing callers.
    return self.delegate(*args, **kwargs)
@@ -1154,6 +1256,22 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def set_flexible_zero_slicing(self, mask: torch.Tensor, depth: int, timesteps: int = None):
    """Enable or disable flexible zero slicing on this wrapper.

    The mask may have an arbitrary spatial shape; zero-valued pixels mark the
    regions to be zeroed. Passing depth=None (or a negative depth) disables
    the feature and clears all related state.
    """
    if depth is None or depth < 0:
        # Disable: drop all state so the patched conv forwards fall through.
        logger.info("Flexible zero slicing is disabled.")
        self.fz_depth = None
        self.fz_mask = None
        self.fz_timesteps = None
        self.fz_mask_dic.clear()
        return

    logger.info(f"Flexible zero slicing is enabled: [depth={depth}]")
    self.fz_depth = depth
    self.fz_mask = mask
    self.fz_timesteps = timesteps
    self.fz_mask_dic.clear()
    # Store the full-resolution mask as (1, 1, H, W) under the (0, 0) key;
    # per-shape resized variants are cached lazily by get_mask_from_mask_dic.
    self.fz_mask_dic[(0, 0)] = mask.unsqueeze(0).unsqueeze(0)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
@@ -1188,7 +1306,14 @@ class InferSdxlUNet2DConditionModel:
# h = x.type(self.dtype)
h = x
self.fz_enable_flag[0] = False
for depth, module in enumerate(_self.input_blocks):
# Flexible Zero Slicing
if self.fz_depth is not None:
self.fz_enable_flag[0] = depth >= self.fz_depth and timesteps[0] > self.fz_timesteps
# print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
@@ -1208,7 +1333,12 @@ class InferSdxlUNet2DConditionModel:
h = call_module(_self.middle_block, h, emb, context)
for module in _self.output_blocks:
for depth, module in enumerate(_self.output_blocks):
# Flexible Zero Slicing
if self.fz_depth is not None and len(self.output_blocks) - depth <= self.fz_depth:
self.fz_enable_flag[0] = False
# print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
# Deep Shrink
if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]: