mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Compare commits
4 Commits
galore
...
flexible-z
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dbe78a8638 | ||
|
|
bae116a031 | ||
|
|
6aa2d99219 | ||
|
|
725bab124b |
79
gen_img.py
79
gen_img.py
@@ -1,5 +1,6 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
@@ -2118,6 +2119,37 @@ def main(args):
|
|||||||
l.extend([im] * args.images_per_prompt)
|
l.extend([im] * args.images_per_prompt)
|
||||||
mask_images = l
|
mask_images = l
|
||||||
|
|
||||||
|
# Flexible Zero Slicing
|
||||||
|
if args.flexible_zero_slicing_depth is not None:
|
||||||
|
# CV2 が必要
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
# mask 画像は背景 255、zero にする部分 0 とする
|
||||||
|
np_mask = np.array(mask_images[0].convert("RGB"))
|
||||||
|
fz_mask = np.full(np_mask.shape, 255, dtype=np.uint8)
|
||||||
|
|
||||||
|
# 各チャンネルに対して処理
|
||||||
|
for i in range(3):
|
||||||
|
# チャンネルを抽出
|
||||||
|
channel = np_mask[:, :, i]
|
||||||
|
|
||||||
|
# 輪郭を検出
|
||||||
|
contours, _ = cv2.findContours(channel, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
|
# 輪郭を新しい配列に描画
|
||||||
|
cv2.drawContours(fz_mask, contours, -1, (0, 0, 0), 1)
|
||||||
|
|
||||||
|
fz_mask = fz_mask.astype(np.float32) / 255.0
|
||||||
|
fz_mask = fz_mask[:, :, 0]
|
||||||
|
fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
|
||||||
|
|
||||||
|
# only for sdxl
|
||||||
|
unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
|
||||||
|
|
||||||
|
# Dilated Conv Hires fix
|
||||||
|
if args.dilated_conv_hires_fix_depth is not None:
|
||||||
|
unet.set_dilated_conv(args.dilated_conv_hires_fix_depth, args.dilated_conv_hires_fix_timesteps)
|
||||||
|
|
||||||
# 画像サイズにオプション指定があるときはリサイズする
|
# 画像サイズにオプション指定があるときはリサイズする
|
||||||
if args.W is not None and args.H is not None:
|
if args.W is not None and args.H is not None:
|
||||||
# highres fix を考慮に入れる
|
# highres fix を考慮に入れる
|
||||||
@@ -2146,10 +2178,17 @@ def main(args):
|
|||||||
|
|
||||||
if args.network_regional_mask_max_color_codes:
|
if args.network_regional_mask_max_color_codes:
|
||||||
# カラーコードでマスクを指定する
|
# カラーコードでマスクを指定する
|
||||||
ch0 = (i + 1) & 1
|
# 0-7: RGB 3bitで8色, 0/255
|
||||||
ch1 = ((i + 1) >> 1) & 1
|
# 8-15: RGB 3bitで8色, 0/127
|
||||||
ch2 = ((i + 1) >> 2) & 1
|
code = (i % 7) + 1
|
||||||
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
|
r = code & 1
|
||||||
|
g = (code & 2) >> 1
|
||||||
|
b = (code & 4) >> 2
|
||||||
|
if i < 7:
|
||||||
|
color = (r * 255, g * 255, b * 255)
|
||||||
|
else:
|
||||||
|
color = (r * 127, g * 127, b * 127)
|
||||||
|
np_mask = np.all(np_mask == color, axis=2)
|
||||||
np_mask = np_mask.astype(np.uint8) * 255
|
np_mask = np_mask.astype(np.uint8) * 255
|
||||||
else:
|
else:
|
||||||
np_mask = np_mask[:, :, i]
|
np_mask = np_mask[:, :, i]
|
||||||
@@ -3312,6 +3351,38 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
|
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# parser.add_argument(
|
||||||
|
# "--flexible_zero_slicing_mask",
|
||||||
|
# type=str,
|
||||||
|
# default=None,
|
||||||
|
# help="mask for flexible zero slicing / flexible zero slicingのマスク",
|
||||||
|
# )
|
||||||
|
parser.add_argument(
|
||||||
|
"--flexible_zero_slicing_depth",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="depth for flexible zero slicing / flexible zero slicingのdepth",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--flexible_zero_slicing_timesteps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="timesteps for flexible zero slicing / flexible zero slicingのtimesteps",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dilated_conv_hires_fix_depth",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="depth for dilated conv hires fix / dilated conv hires fixのdepth",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dilated_conv_hires_fix_timesteps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="timesteps for dilated conv hires fix / dilated conv hires fixのtimesteps",
|
||||||
|
)
|
||||||
|
|
||||||
# # parser.add_argument(
|
# # parser.add_argument(
|
||||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||||
# )
|
# )
|
||||||
|
|||||||
@@ -24,15 +24,17 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any, Optional
|
from typing import Any, List, Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .utils import setup_logging
|
from .utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
IN_CHANNELS: int = 4
|
IN_CHANNELS: int = 4
|
||||||
@@ -1114,6 +1116,46 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def get_mask_from_mask_dic(mask_dic, shape):
    """Return the mask for a feature map of spatial size ``shape``, resizing on demand.

    ``mask_dic`` is used as a cache keyed by spatial size. The full-resolution
    mask is stored under the special key ``(0, 0)``; the first request for a
    given ``shape`` resizes that mask, stores the result under ``shape`` and
    returns it, and later requests return the cached tensor.

    Args:
        mask_dic: dict mapping a spatial size to a mask tensor, or ``None``.
            The ``(0, 0)`` entry is passed to ``F.interpolate`` with a 2-element
            ``size``, so it is presumably 4-D ``(1, 1, H, W)`` - TODO confirm
            against the code that fills the dict.
        shape: target spatial size, used both as the dict key and as the
            ``size`` argument of ``F.interpolate``.

    Returns:
        The mask for ``shape`` (same dtype as the base mask), or ``None`` when
        ``mask_dic`` is ``None``/empty or holds no ``(0, 0)`` base mask.
    """
    if mask_dic is None or len(mask_dic) == 0:
        return None

    cached = mask_dic.get(shape, None)
    if cached is not None:
        return cached

    # Nothing cached for this size: resize from the original mask.
    base_mask = mask_dic.get((0, 0), None)
    if base_mask is None:
        # No base mask to derive from; report "no mask" like the empty-dict case
        # instead of failing with AttributeError on ``None.dtype``.
        return None

    org_dtype = base_mask.dtype
    if org_dtype == torch.bfloat16:
        # NOTE(review): presumably bfloat16 is not supported by the area
        # interpolation on some backends, hence the float32 round trip - confirm.
        base_mask = base_mask.to(torch.float32)

    # "area" averages the covered source pixels, so any target pixel that
    # touches a zero ends up strictly below 1 ...
    resized = F.interpolate(base_mask, size=shape, mode="area")
    # ... and only pixels that are exactly 1 are kept as 1 in the resized mask.
    mask = (resized == 1).to(dtype=org_dtype, device=resized.device)
    mask_dic[shape] = mask

    return mask
|
||||||
|
|
||||||
|
|
||||||
|
# class Conv2dZeroSlicing(nn.Conv2d):
|
||||||
|
# def __init__(self, *args, **kwargs):
|
||||||
|
# super().__init__(*args, **kwargs)
|
||||||
|
# self.mask_dic = None
|
||||||
|
# self.enable_flag = None
|
||||||
|
|
||||||
|
# def set_reference_for_enable_and_mask_dic(self, enable_flag, mask_dic):
|
||||||
|
# self.enable_flag = enable_flag
|
||||||
|
# self.mask_dic = mask_dic
|
||||||
|
|
||||||
|
# def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
# print(self.enable_flag, self.mask_dic, input.shape[-2:])
|
||||||
|
# if self.enable_flag is None or not self.enable_flag[0] or self.mask_dic is None or len(self.mask_dic) == 0:
|
||||||
|
# return super().forward(input)
|
||||||
|
|
||||||
|
# mask = get_mask_from_mask_dic(self.mask_dic, input.shape[-2:])
|
||||||
|
# if mask is not None:
|
||||||
|
# input = input * mask
|
||||||
|
# return super().forward(input)
|
||||||
|
|
||||||
|
|
||||||
class InferSdxlUNet2DConditionModel:
|
class InferSdxlUNet2DConditionModel:
|
||||||
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
|
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
|
||||||
self.delegate = original_unet
|
self.delegate = original_unet
|
||||||
@@ -1129,10 +1171,96 @@ class InferSdxlUNet2DConditionModel:
|
|||||||
self.ds_timesteps_2 = None
|
self.ds_timesteps_2 = None
|
||||||
self.ds_ratio = None
|
self.ds_ratio = None
|
||||||
|
|
||||||
|
# Dilated Conv
|
||||||
|
self.dc_depth = None
|
||||||
|
self.dc_timesteps = None
|
||||||
|
self.dc_enable_flag = [False]
|
||||||
|
for name, module in self.delegate.named_modules():
|
||||||
|
if isinstance(module, nn.Conv2d):
|
||||||
|
if module.kernel_size == (3, 3) and module.dilation == (1, 1):
|
||||||
|
module.dc_enable_flag = self.dc_enable_flag
|
||||||
|
|
||||||
|
# replace forward method
|
||||||
|
module.dc_original_forward = module.forward
|
||||||
|
|
||||||
|
def make_forward_dilated_conv(module):
|
||||||
|
def forward_conv2d_dilated_conv(input: torch.Tensor) -> torch.Tensor:
|
||||||
|
if module.dc_enable_flag[0]:
|
||||||
|
module.dilation = (1, 2)
|
||||||
|
module.padding = (1, 2)
|
||||||
|
else:
|
||||||
|
module.dilation = (1, 1)
|
||||||
|
module.padding = (1, 1)
|
||||||
|
return module.dc_original_forward(input)
|
||||||
|
|
||||||
|
return forward_conv2d_dilated_conv
|
||||||
|
|
||||||
|
module.forward = make_forward_dilated_conv(module)
|
||||||
|
|
||||||
|
# flexible zero slicing
|
||||||
|
self.fz_depth = None
|
||||||
|
self.fz_enable_flag = [False]
|
||||||
|
self.fz_mask_dic = {}
|
||||||
|
for name, module in self.delegate.named_modules():
|
||||||
|
if isinstance(module, nn.Conv2d):
|
||||||
|
if module.kernel_size == (3, 3):
|
||||||
|
module.fz_enable_flag = self.fz_enable_flag
|
||||||
|
module.fz_mask_dic = self.fz_mask_dic
|
||||||
|
|
||||||
|
# replace forward method
|
||||||
|
module.fz_original_forward = module.forward
|
||||||
|
|
||||||
|
def make_forward(module):
|
||||||
|
def forward_conv2d_zero_slicing(input: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not module.fz_enable_flag[0] or len(module.fz_mask_dic) == 0:
|
||||||
|
return module.fz_original_forward(input)
|
||||||
|
|
||||||
|
mask = get_mask_from_mask_dic(module.fz_mask_dic, input.shape[-2:])
|
||||||
|
input = input * mask
|
||||||
|
return module.fz_original_forward(input)
|
||||||
|
|
||||||
|
return forward_conv2d_zero_slicing
|
||||||
|
|
||||||
|
module.forward = make_forward(module)
|
||||||
|
|
||||||
|
# def forward_conv2d_zero_slicing(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
# print(self.__class__.__name__, "forward_conv2d_zero_slicing")
|
||||||
|
# print(self.enable_flag, self.mask_dic, input.shape[-2:])
|
||||||
|
# if self.fz_depth is None or not self.fz_enable_flag[0] or self.fz_mask_dic is None or len(self.fz_mask_dic) == 0:
|
||||||
|
# return self.original_forward(input)
|
||||||
|
|
||||||
|
# mask = get_mask_from_mask_dic(self.fz_mask_dic, input.shape[-2:])
|
||||||
|
# if mask is not None:
|
||||||
|
# input = input * mask
|
||||||
|
# return self.original_forward(input)
|
||||||
|
|
||||||
|
# for name, module in list(self.delegate.named_modules()):
|
||||||
|
# if isinstance(module, nn.Conv2d):
|
||||||
|
# if module.kernel_size == (3, 3):
|
||||||
|
# # replace Conv2d with Conv2dZeroSlicing
|
||||||
|
# new_conv2d = Conv2dZeroSlicing(
|
||||||
|
# module.in_channels,
|
||||||
|
# module.out_channels,
|
||||||
|
# module.kernel_size,
|
||||||
|
# module.stride,
|
||||||
|
# module.padding,
|
||||||
|
# module.dilation,
|
||||||
|
# module.groups,
|
||||||
|
# module.bias is not None,
|
||||||
|
# module.padding_mode,
|
||||||
|
# )
|
||||||
|
# new_conv2d.set_reference_for_enable_and_mask_dic(self.fz_enable_flag, self.fz_mask_dic)
|
||||||
|
# print(f"replace {name} with Conv2dZeroSlicing")
|
||||||
|
# setattr(self.delegate, name, new_conv2d)
|
||||||
|
|
||||||
|
# # copy parameters
|
||||||
|
# new_conv2d.weight = module.weight
|
||||||
|
# new_conv2d.bias = module.bias
|
||||||
|
|
||||||
# call original model's methods
|
# call original model's methods
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self.delegate, name)
|
return getattr(self.delegate, name)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.delegate(*args, **kwargs)
|
return self.delegate(*args, **kwargs)
|
||||||
|
|
||||||
@@ -1154,6 +1282,32 @@ class InferSdxlUNet2DConditionModel:
|
|||||||
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||||
self.ds_ratio = ds_ratio
|
self.ds_ratio = ds_ratio
|
||||||
|
|
||||||
|
def set_flexible_zero_slicing(self, mask: torch.Tensor, depth: int, timesteps: Optional[int] = None):
    """Enable or disable flexible zero slicing.

    Args:
        mask: mask tensor where 0 marks the positions to be zeroed. Its spatial
            size is arbitrary; it is stored with two leading singleton
            dimensions under the ``(0, 0)`` key of ``self.fz_mask_dic``
            (presumably a 2-D ``(H, W)`` tensor - TODO confirm with callers).
        depth: block depth used as the threshold in ``forward``; ``None`` or a
            negative value disables the feature.
        timesteps: timestep threshold; ``forward`` evaluates
            ``timesteps[0] > self.fz_timesteps``. ``None`` is stored as 0 so
            that comparison never sees ``None``.

    Raises:
        ValueError: if the feature is being enabled without a mask.
    """
    # The cached, per-resolution masks are stale in both cases.
    self.fz_mask_dic.clear()

    if depth is None or depth < 0:
        logger.info("Flexible zero slicing is disabled.")
        self.fz_depth = None
        self.fz_mask = None
        self.fz_timesteps = None
        return

    if mask is None:
        raise ValueError("mask is required to enable flexible zero slicing")

    logger.info(f"Flexible zero slicing is enabled: [depth={depth}]")
    self.fz_depth = depth
    self.fz_mask = mask
    # A None threshold would raise TypeError in the ``>`` comparison in forward.
    self.fz_timesteps = timesteps if timesteps is not None else 0
    # (0, 0) is the key under which the full-resolution mask is looked up.
    self.fz_mask_dic[(0, 0)] = mask.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
def set_dilated_conv(self, depth: int, timesteps: Optional[int] = None):
    """Enable or disable the dilated-conv hires fix.

    Args:
        depth: block depth used as the threshold in ``forward``; ``None`` or a
            negative value disables the feature.
        timesteps: timestep threshold; ``forward`` evaluates
            ``timesteps[0] > self.dc_timesteps``. ``None`` is stored as 0 so
            that comparison never sees ``None``.
    """
    if depth is None or depth < 0:
        logger.info("Dilated Conv is disabled.")
        self.dc_depth = None
        self.dc_timesteps = None
        return

    logger.info(f"Dilated Conv is enabled: [depth={depth}]")
    self.dc_depth = depth
    # A None threshold would raise TypeError in the ``>`` comparison in forward.
    self.dc_timesteps = timesteps if timesteps is not None else 0
|
||||||
|
|
||||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
|
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
|
||||||
@@ -1188,7 +1342,18 @@ class InferSdxlUNet2DConditionModel:
|
|||||||
# h = x.type(self.dtype)
|
# h = x.type(self.dtype)
|
||||||
h = x
|
h = x
|
||||||
|
|
||||||
|
self.fz_enable_flag[0] = False
|
||||||
|
|
||||||
for depth, module in enumerate(_self.input_blocks):
|
for depth, module in enumerate(_self.input_blocks):
|
||||||
|
# Dilated Conv
|
||||||
|
if self.dc_depth is not None:
|
||||||
|
self.dc_enable_flag[0] = depth >= self.dc_depth and timesteps[0] > self.dc_timesteps
|
||||||
|
|
||||||
|
# Flexible Zero Slicing
|
||||||
|
if self.fz_depth is not None:
|
||||||
|
self.fz_enable_flag[0] = depth >= self.fz_depth and timesteps[0] > self.fz_timesteps
|
||||||
|
# print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
|
||||||
|
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
if self.ds_depth_1 is not None:
|
if self.ds_depth_1 is not None:
|
||||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||||
@@ -1208,7 +1373,16 @@ class InferSdxlUNet2DConditionModel:
|
|||||||
|
|
||||||
h = call_module(_self.middle_block, h, emb, context)
|
h = call_module(_self.middle_block, h, emb, context)
|
||||||
|
|
||||||
for module in _self.output_blocks:
|
for depth, module in enumerate(_self.output_blocks):
|
||||||
|
# Dilated Conv
|
||||||
|
if self.dc_depth is not None and len(_self.output_blocks) - depth <= self.dc_depth:
|
||||||
|
self.dc_enable_flag[0] = False
|
||||||
|
|
||||||
|
# Flexible Zero Slicing
|
||||||
|
if self.fz_depth is not None and len(self.output_blocks) - depth <= self.fz_depth:
|
||||||
|
self.fz_enable_flag[0] = False
|
||||||
|
# print(f"Flexible Zero Slicing: depth={depth}, timesteps={timesteps[0]}, enable={self.fz_enable_flag[0]}")
|
||||||
|
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
if self.ds_depth_1 is not None:
|
if self.ds_depth_1 is not None:
|
||||||
if hs[-1].shape[-2:] != h.shape[-2:]:
|
if hs[-1].shape[-2:] != h.shape[-2:]:
|
||||||
|
|||||||
Reference in New Issue
Block a user