make separate U-Net for inference

This commit is contained in:
Kohya S
2023-11-26 18:11:30 +09:00
parent fc8649d80f
commit c61e3bf4c9
4 changed files with 366 additions and 82 deletions

View File

@@ -24,7 +24,7 @@
import math
from types import SimpleNamespace
from typing import Optional
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
@@ -1013,31 +1013,6 @@ class SdxlUNet2DConditionModel(nn.Module):
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
)
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
@@ -1120,7 +1095,97 @@ class SdxlUNet2DConditionModel(nn.Module):
# h = x.type(self.dtype)
h = x
for depth, module in enumerate(self.input_blocks):
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
return h
class InferSdxlUNet2DConditionModel:
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
self.delegate = original_unet
# override original model's forward method: because forward is not called by `__call__`
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
self.delegate.forward = self.forward
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
"""
_self = self.delegate
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
# assert x.dtype == _self.dtype
emb = emb + _self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
# h = x.type(self.dtype)
h = x
for depth, module in enumerate(_self.input_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
@@ -1138,9 +1203,9 @@ class SdxlUNet2DConditionModel(nn.Module):
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
h = call_module(_self.middle_block, h, emb, context)
for module in self.output_blocks:
for module in _self.output_blocks:
# Deep Shrink
if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]:
@@ -1156,7 +1221,7 @@ class SdxlUNet2DConditionModel(nn.Module):
h = resize_like(h, x)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
h = call_module(_self.out, h, emb, context)
return h