mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add Deep Shrink
This commit is contained in:
@@ -266,6 +266,23 @@ def get_timestep_embedding(
|
||||
return emb
|
||||
|
||||
|
||||
# Deep Shrink: this helper is intentionally defined here rather than shared with other modules, to minimize dependencies.
|
||||
def resize_like(x, target, mode="bicubic", align_corners=False):
    """Resize ``x`` so that its last two dimensions match those of ``target``.

    Args:
        x: Tensor to resize. Only its last two dimensions are changed.
        target: Tensor whose last two dimensions give the output size.
        mode: Interpolation mode forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate`` only for the modes that
            accept it (linear, bilinear, bicubic, trilinear). It is ignored for
            the other modes ("nearest", "area", "nearest-exact"), which raise
            ``ValueError`` when ``align_corners`` is given.

    Returns:
        ``x`` itself when the sizes already match; otherwise the interpolated
        tensor, in the same dtype as the input.
    """
    # Nothing to do: skip the interpolation and any dtype round-trip.
    if x.shape[-2:] == target.shape[-2:]:
        return x

    org_dtype = x.dtype
    # bfloat16 input is interpolated in float32 and cast back afterwards
    # (presumably because interpolation kernels may lack bfloat16 support).
    needs_upcast = org_dtype == torch.bfloat16
    if needs_upcast:
        x = x.to(torch.float32)

    # Only the interpolating modes accept align_corners; passing it (even as
    # False) with any other mode makes F.interpolate raise ValueError.
    extra_kwargs = {}
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        extra_kwargs["align_corners"] = align_corners

    x = F.interpolate(x, size=target.shape[-2:], mode=mode, **extra_kwargs)

    if needs_upcast:
        x = x.to(org_dtype)
    return x
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
if self.weight.dtype != torch.float32:
|
||||
@@ -996,6 +1013,31 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
||||
)
|
||||
|
||||
# Deep Shrink
|
||||
self.ds_depth_1 = None
|
||||
self.ds_depth_2 = None
|
||||
self.ds_timesteps_1 = None
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
|
||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
    """Enable or disable Deep Shrink by storing its settings on the model.

    Passing ``ds_depth_1=None`` disables the feature and resets every
    ``ds_*`` attribute to ``None``. Otherwise the given values are stored;
    a missing ``ds_depth_2`` is stored as ``-1`` and a missing
    ``ds_timesteps_2`` as ``1000``. A message describing the result is
    printed in both cases.

    Args:
        ds_depth_1: Depth for the first shrink stage, or ``None`` to disable.
        ds_timesteps_1: Timestep threshold for the first stage.
        ds_depth_2: Optional depth for the second stage.
        ds_timesteps_2: Optional timestep threshold for the second stage.
        ds_ratio: Scale factor stored as ``self.ds_ratio``.
    """
    if ds_depth_1 is not None:
        # The message reports the arguments as passed, before defaults are filled in.
        print(
            f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
        )
        second_depth = -1 if ds_depth_2 is None else ds_depth_2
        second_timesteps = 1000 if ds_timesteps_2 is None else ds_timesteps_2
        settings = (ds_depth_1, ds_timesteps_1, second_depth, second_timesteps, ds_ratio)
    else:
        print("Deep Shrink is disabled.")
        settings = (None, None, None, None, None)

    (
        self.ds_depth_1,
        self.ds_timesteps_1,
        self.ds_depth_2,
        self.ds_timesteps_2,
        self.ds_ratio,
    ) = settings
|
||||
|
||||
# region diffusers compatibility
|
||||
def prepare_config(self):
|
||||
self.config = SimpleNamespace()
|
||||
@@ -1077,16 +1119,42 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
|
||||
# h = x.type(self.dtype)
|
||||
h = x
|
||||
for module in self.input_blocks:
|
||||
|
||||
for depth, module in enumerate(self.input_blocks):
|
||||
# Deep Shrink
|
||||
if self.ds_depth_1 is not None:
|
||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||
self.ds_depth_2 is not None
|
||||
and depth == self.ds_depth_2
|
||||
and timesteps[0] < self.ds_timesteps_1
|
||||
and timesteps[0] >= self.ds_timesteps_2
|
||||
):
|
||||
# print("downsample", h.shape, self.ds_ratio)
|
||||
org_dtype = h.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
h = h.to(torch.float32)
|
||||
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
||||
|
||||
h = call_module(module, h, emb, context)
|
||||
hs.append(h)
|
||||
|
||||
h = call_module(self.middle_block, h, emb, context)
|
||||
|
||||
for module in self.output_blocks:
|
||||
# Deep Shrink
|
||||
if self.ds_depth_1 is not None:
|
||||
if hs[-1].shape[-2:] != h.shape[-2:]:
|
||||
# print("upsample", h.shape, hs[-1].shape)
|
||||
h = resize_like(h, hs[-1])
|
||||
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = call_module(module, h, emb, context)
|
||||
|
||||
# Deep Shrink: in case of depth 0
|
||||
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
|
||||
# print("upsample", h.shape, x.shape)
|
||||
h = resize_like(h, x)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = call_module(self.out, h, emb, context)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user