mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
make separate U-Net for inference
This commit is contained in:
@@ -24,7 +24,7 @@
|
||||
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
@@ -1013,31 +1013,6 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
||||
)
|
||||
|
||||
# Deep Shrink
|
||||
self.ds_depth_1 = None
|
||||
self.ds_depth_2 = None
|
||||
self.ds_timesteps_1 = None
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
|
||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
    """Store (or clear) the Deep Shrink settings on this model.

    Passing ``ds_depth_1=None`` disables Deep Shrink: all five ``ds_*``
    attributes are reset to ``None``. Any other value enables it and records
    the given settings, substituting ``-1`` for a missing ``ds_depth_2`` and
    ``1000`` for a missing ``ds_timesteps_2``.

    NOTE(review): how the stored values are consumed is not visible in this
    part of the file; confirm against the forward pass that reads them.

    Args:
        ds_depth_1: first Deep Shrink depth, or ``None`` to disable.
        ds_timesteps_1: timestep setting paired with ``ds_depth_1``.
        ds_depth_2: optional second depth (stored as ``-1`` when omitted).
        ds_timesteps_2: optional second timestep setting (stored as ``1000``
            when omitted).
        ds_ratio: ratio setting, stored unchanged.
    """
    if ds_depth_1 is not None:
        # The message reports the arguments exactly as passed; fallback values
        # are substituted only afterwards.
        print(
            f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
        )
        if ds_depth_2 is None:
            ds_depth_2 = -1
        if ds_timesteps_2 is None:
            ds_timesteps_2 = 1000
        settings = (ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
    else:
        print("Deep Shrink is disabled.")
        settings = (None, None, None, None, None)

    (
        self.ds_depth_1,
        self.ds_timesteps_1,
        self.ds_depth_2,
        self.ds_timesteps_2,
        self.ds_ratio,
    ) = settings
|
||||
|
||||
# region diffusers compatibility
|
||||
def prepare_config(self):
|
||||
self.config = SimpleNamespace()
|
||||
@@ -1120,7 +1095,97 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
# h = x.type(self.dtype)
|
||||
h = x
|
||||
|
||||
for depth, module in enumerate(self.input_blocks):
|
||||
for module in self.input_blocks:
|
||||
h = call_module(module, h, emb, context)
|
||||
hs.append(h)
|
||||
|
||||
h = call_module(self.middle_block, h, emb, context)
|
||||
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = call_module(module, h, emb, context)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = call_module(self.out, h, emb, context)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class InferSdxlUNet2DConditionModel:
|
||||
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
    """Wrap ``original_unet`` for inference with optional Deep Shrink.

    Args:
        original_unet: the model to wrap; stored as ``self.delegate``. Its
            ``forward`` attribute is replaced by this wrapper's ``forward``.
        **kwargs: accepted but not used by this constructor.
    """
    # Store the delegate first: __getattr__ resolves every attribute that is
    # not found on the wrapper through it.
    self.delegate = original_unet

    # Install this wrapper's forward on the wrapped model instance. Overriding
    # `__call__` on the wrapper alone is not enough, because calling the
    # wrapped model dispatches to its own `forward` (nn.Module handles
    # `forward` specially).
    original_unet.forward = self.forward

    # Deep Shrink starts out disabled; see set_deep_shrink().
    self.ds_depth_1 = self.ds_depth_2 = None
    self.ds_timesteps_1 = self.ds_timesteps_2 = None
    self.ds_ratio = None
|
||||
|
||||
# call original model's methods
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.delegate, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Call the wrapped model with the given arguments and return its result.

    All positional and keyword arguments are passed through unchanged to
    ``self.delegate``.
    """
    wrapped = self.delegate
    return wrapped(*args, **kwargs)
|
||||
|
||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
    """Enable or disable Deep Shrink on this inference wrapper.

    Passing ``ds_depth_1=None`` disables Deep Shrink: all five ``ds_*``
    attributes are reset to ``None``. Any other value enables it and records
    the given settings, substituting ``-1`` for a missing ``ds_depth_2`` and
    ``1000`` for a missing ``ds_timesteps_2``.

    Args:
        ds_depth_1: index into the input blocks that ``forward`` compares
            against while iterating them, or ``None`` to disable.
        ds_timesteps_1: threshold paired with ``ds_depth_1``; ``forward``
            tests ``timesteps[0] >= ds_timesteps_1``.
        ds_depth_2: optional second depth (stored as ``-1`` when omitted).
            NOTE(review): presumably a second stage; its use is not visible
            in this view.
        ds_timesteps_2: optional second threshold (stored as ``1000`` when
            omitted).
        ds_ratio: ratio setting, stored unchanged. NOTE(review): presumably
            the downscale factor; confirm against ``forward``.
    """
    if ds_depth_1 is not None:
        # The message reports the arguments exactly as passed; fallback values
        # are substituted only afterwards.
        print(
            f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
        )
        if ds_depth_2 is None:
            ds_depth_2 = -1
        if ds_timesteps_2 is None:
            ds_timesteps_2 = 1000
        settings = (ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
    else:
        print("Deep Shrink is disabled.")
        settings = (None, None, None, None, None)

    (
        self.ds_depth_1,
        self.ds_timesteps_1,
        self.ds_depth_2,
        self.ds_timesteps_2,
        self.ds_ratio,
    ) = settings
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
r"""
|
||||
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
|
||||
"""
|
||||
_self = self.delegate
|
||||
|
||||
# broadcast timesteps to batch dimension
|
||||
timesteps = timesteps.expand(x.shape[0])
|
||||
|
||||
hs = []
|
||||
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
|
||||
t_emb = t_emb.to(x.dtype)
|
||||
emb = _self.time_embed(t_emb)
|
||||
|
||||
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
||||
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
||||
# assert x.dtype == _self.dtype
|
||||
emb = emb + _self.label_emb(y)
|
||||
|
||||
def call_module(module, h, emb, context):
|
||||
x = h
|
||||
for layer in module:
|
||||
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
||||
if isinstance(layer, ResnetBlock2D):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, Transformer2DModel):
|
||||
x = layer(x, context)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
# h = x.type(self.dtype)
|
||||
h = x
|
||||
|
||||
for depth, module in enumerate(_self.input_blocks):
|
||||
# Deep Shrink
|
||||
if self.ds_depth_1 is not None:
|
||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||
@@ -1138,9 +1203,9 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
h = call_module(module, h, emb, context)
|
||||
hs.append(h)
|
||||
|
||||
h = call_module(self.middle_block, h, emb, context)
|
||||
h = call_module(_self.middle_block, h, emb, context)
|
||||
|
||||
for module in self.output_blocks:
|
||||
for module in _self.output_blocks:
|
||||
# Deep Shrink
|
||||
if self.ds_depth_1 is not None:
|
||||
if hs[-1].shape[-2:] != h.shape[-2:]:
|
||||
@@ -1156,7 +1221,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
h = resize_like(h, x)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = call_module(self.out, h, emb, context)
|
||||
h = call_module(_self.out, h, emb, context)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
Reference in New Issue
Block a user