mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
scale crafter
This commit is contained in:
20
gen_img.py
20
gen_img.py
@@ -2138,7 +2138,7 @@ def main(args):
|
||||
|
||||
# 輪郭を新しい配列に描画
|
||||
cv2.drawContours(fz_mask, contours, -1, (0, 0, 0), 1)
|
||||
|
||||
|
||||
fz_mask = fz_mask.astype(np.float32) / 255.0
|
||||
fz_mask = fz_mask[:, :, 0]
|
||||
fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
|
||||
@@ -2146,6 +2146,10 @@ def main(args):
|
||||
# only for sdxl
|
||||
unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
|
||||
|
||||
# Dilated Conv Hires fix
|
||||
if args.dilated_conv_hires_fix_depth is not None:
|
||||
unet.set_dilated_conv(args.dilated_conv_hires_fix_depth, args.dilated_conv_hires_fix_timesteps)
|
||||
|
||||
# 画像サイズにオプション指定があるときはリサイズする
|
||||
if args.W is not None and args.H is not None:
|
||||
# highres fix を考慮に入れる
|
||||
@@ -3365,6 +3369,20 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="timesteps for flexible zero slicing / flexible zero slicingのtimesteps",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dilated_conv_hires_fix_depth",
|
||||
type=int,
|
||||
default=None,
|
||||
help="depth for dilated conv hires fix / dilated conv hires fixのdepth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dilated_conv_hires_fix_timesteps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="timesteps for dilated conv hires fix / dilated conv hires fixのtimesteps",
|
||||
)
|
||||
|
||||
# # parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
@@ -1171,6 +1171,32 @@ class InferSdxlUNet2DConditionModel:
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
|
||||
# Dilated Conv
|
||||
self.dc_depth = None
|
||||
self.dc_timesteps = None
|
||||
self.dc_enable_flag = [False]
|
||||
for name, module in self.delegate.named_modules():
|
||||
if isinstance(module, nn.Conv2d):
|
||||
if module.kernel_size == (3, 3) and module.dilation == (1, 1):
|
||||
module.dc_enable_flag = self.dc_enable_flag
|
||||
|
||||
# replace forward method
|
||||
module.dc_original_forward = module.forward
|
||||
|
||||
def make_forward_dilated_conv(module):
|
||||
def forward_conv2d_dilated_conv(input: torch.Tensor) -> torch.Tensor:
|
||||
if module.dc_enable_flag[0]:
|
||||
module.dilation = (1, 2)
|
||||
module.padding = (1, 2)
|
||||
else:
|
||||
module.dilation = (1, 1)
|
||||
module.padding = (1, 1)
|
||||
return module.dc_original_forward(input)
|
||||
|
||||
return forward_conv2d_dilated_conv
|
||||
|
||||
module.forward = make_forward_dilated_conv(module)
|
||||
|
||||
# flexible zero slicing
|
||||
self.fz_depth = None
|
||||
self.fz_enable_flag = [False]
|
||||
@@ -1178,20 +1204,20 @@ class InferSdxlUNet2DConditionModel:
|
||||
for name, module in self.delegate.named_modules():
|
||||
if isinstance(module, nn.Conv2d):
|
||||
if module.kernel_size == (3, 3):
|
||||
module.enable_flag = self.fz_enable_flag
|
||||
module.mask_dic = self.fz_mask_dic
|
||||
module.fz_enable_flag = self.fz_enable_flag
|
||||
module.fz_mask_dic = self.fz_mask_dic
|
||||
|
||||
# replace forward method
|
||||
module.original_forward = module.forward
|
||||
module.fz_original_forward = module.forward
|
||||
|
||||
def make_forward(module):
|
||||
def forward_conv2d_zero_slicing(input: torch.Tensor) -> torch.Tensor:
|
||||
if not module.enable_flag[0] or len(module.mask_dic) == 0:
|
||||
return module.original_forward(input)
|
||||
if not module.fz_enable_flag[0] or len(module.fz_mask_dic) == 0:
|
||||
return module.fz_original_forward(input)
|
||||
|
||||
mask = get_mask_from_mask_dic(module.mask_dic, input.shape[-2:])
|
||||
mask = get_mask_from_mask_dic(module.fz_mask_dic, input.shape[-2:])
|
||||
input = input * mask
|
||||
return module.original_forward(input)
|
||||
return module.fz_original_forward(input)
|
||||
|
||||
return forward_conv2d_zero_slicing
|
||||
|
||||
@@ -1272,6 +1298,16 @@ class InferSdxlUNet2DConditionModel:
|
||||
self.fz_mask_dic.clear()
|
||||
self.fz_mask_dic[(0, 0)] = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
def set_dilated_conv(self, depth: int, timesteps: int = None):
|
||||
if depth is None or depth < 0:
|
||||
logger.info("Dilated Conv is disabled.")
|
||||
self.dc_depth = None
|
||||
self.dc_timesteps = None
|
||||
else:
|
||||
logger.info(f"Dilated Conv is enabled: [depth={depth}]")
|
||||
self.dc_depth = depth
|
||||
self.dc_timesteps = timesteps
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
r"""
|
||||
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
|
||||
@@ -1309,6 +1345,10 @@ class InferSdxlUNet2DConditionModel:
|
||||
self.fz_enable_flag[0] = False
|
||||
|
||||
for depth, module in enumerate(_self.input_blocks):
|
||||
# Dilated Conv
|
||||
if self.dc_depth is not None:
|
||||
self.dc_enable_flag[0] = depth >= self.dc_depth and timesteps[0] > self.dc_timesteps
|
||||
|
||||
# Flexible Zero Slicing
|
||||
if self.fz_depth is not None:
|
||||
self.fz_enable_flag[0] = depth >= self.fz_depth and timesteps[0] > self.fz_timesteps
|
||||
@@ -1334,6 +1374,10 @@ class InferSdxlUNet2DConditionModel:
|
||||
h = call_module(_self.middle_block, h, emb, context)
|
||||
|
||||
for depth, module in enumerate(_self.output_blocks):
|
||||
# Dilated Conv
|
||||
if self.dc_depth is not None and len(_self.output_blocks) - depth <= self.dc_depth:
|
||||
self.dc_enable_flag[0] = False
|
||||
|
||||
# Flexible Zero Slicing
|
||||
if self.fz_depth is not None and len(self.output_blocks) - depth <= self.fz_depth:
|
||||
self.fz_enable_flag[0] = False
|
||||
|
||||
Reference in New Issue
Block a user