scale crafter

This commit is contained in:
Kohya S
2024-02-25 08:53:43 +09:00
parent bae116a031
commit dbe78a8638
2 changed files with 70 additions and 8 deletions

View File

@@ -1171,6 +1171,32 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = None
self.ds_ratio = None
# Dilated Conv
self.dc_depth = None
self.dc_timesteps = None
self.dc_enable_flag = [False]
for name, module in self.delegate.named_modules():
if isinstance(module, nn.Conv2d):
if module.kernel_size == (3, 3) and module.dilation == (1, 1):
module.dc_enable_flag = self.dc_enable_flag
# replace forward method
module.dc_original_forward = module.forward
def make_forward_dilated_conv(module):
def forward_conv2d_dilated_conv(input: torch.Tensor) -> torch.Tensor:
if module.dc_enable_flag[0]:
module.dilation = (1, 2)
module.padding = (1, 2)
else:
module.dilation = (1, 1)
module.padding = (1, 1)
return module.dc_original_forward(input)
return forward_conv2d_dilated_conv
module.forward = make_forward_dilated_conv(module)
# flexible zero slicing
self.fz_depth = None
self.fz_enable_flag = [False]
@@ -1178,20 +1204,20 @@ class InferSdxlUNet2DConditionModel:
for name, module in self.delegate.named_modules():
if isinstance(module, nn.Conv2d):
if module.kernel_size == (3, 3):
module.enable_flag = self.fz_enable_flag
module.mask_dic = self.fz_mask_dic
module.fz_enable_flag = self.fz_enable_flag
module.fz_mask_dic = self.fz_mask_dic
# replace forward method
module.original_forward = module.forward
module.fz_original_forward = module.forward
def make_forward(module):
def forward_conv2d_zero_slicing(input: torch.Tensor) -> torch.Tensor:
if not module.enable_flag[0] or len(module.mask_dic) == 0:
return module.original_forward(input)
if not module.fz_enable_flag[0] or len(module.fz_mask_dic) == 0:
return module.fz_original_forward(input)
mask = get_mask_from_mask_dic(module.mask_dic, input.shape[-2:])
mask = get_mask_from_mask_dic(module.fz_mask_dic, input.shape[-2:])
input = input * mask
return module.original_forward(input)
return module.fz_original_forward(input)
return forward_conv2d_zero_slicing
@@ -1272,6 +1298,16 @@ class InferSdxlUNet2DConditionModel:
self.fz_mask_dic.clear()
self.fz_mask_dic[(0, 0)] = mask.unsqueeze(0).unsqueeze(0)
def set_dilated_conv(self, depth: int, timesteps: int = None):
if depth is None or depth < 0:
logger.info("Dilated Conv is disabled.")
self.dc_depth = None
self.dc_timesteps = None
else:
logger.info(f"Dilated Conv is enabled: [depth={depth}]")
self.dc_depth = depth
self.dc_timesteps = timesteps
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
@@ -1309,6 +1345,10 @@ class InferSdxlUNet2DConditionModel:
self.fz_enable_flag[0] = False
for depth, module in enumerate(_self.input_blocks):
# Dilated Conv
if self.dc_depth is not None:
self.dc_enable_flag[0] = depth >= self.dc_depth and timesteps[0] > self.dc_timesteps
# Flexible Zero Slicing
if self.fz_depth is not None:
self.fz_enable_flag[0] = depth >= self.fz_depth and timesteps[0] > self.fz_timesteps
@@ -1334,6 +1374,10 @@ class InferSdxlUNet2DConditionModel:
h = call_module(_self.middle_block, h, emb, context)
for depth, module in enumerate(_self.output_blocks):
# Dilated Conv
if self.dc_depth is not None and len(_self.output_blocks) - depth <= self.dc_depth:
self.dc_enable_flag[0] = False
# Flexible Zero Slicing
if self.fz_depth is not None and len(self.output_blocks) - depth <= self.fz_depth:
self.fz_enable_flag[0] = False