mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
make separate U-Net for inference
This commit is contained in:
@@ -105,7 +105,7 @@ import library.train_util as train_util
|
|||||||
from networks.lora import LoRANetwork
|
from networks.lora import LoRANetwork
|
||||||
import tools.original_control_net as original_control_net
|
import tools.original_control_net as original_control_net
|
||||||
from tools.original_control_net import ControlNetInfo
|
from tools.original_control_net import ControlNetInfo
|
||||||
from library.original_unet import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
|
||||||
from library.original_unet import FlashAttentionFunction
|
from library.original_unet import FlashAttentionFunction
|
||||||
|
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
@@ -378,7 +378,7 @@ class PipelineLike:
|
|||||||
vae: AutoencoderKL,
|
vae: AutoencoderKL,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
unet: UNet2DConditionModel,
|
unet: InferUNet2DConditionModel,
|
||||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||||
clip_skip: int,
|
clip_skip: int,
|
||||||
clip_model: CLIPModel,
|
clip_model: CLIPModel,
|
||||||
@@ -2196,6 +2196,7 @@ def main(args):
|
|||||||
)
|
)
|
||||||
original_unet.load_state_dict(unet.state_dict())
|
original_unet.load_state_dict(unet.state_dict())
|
||||||
unet = original_unet
|
unet = original_unet
|
||||||
|
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)
|
||||||
|
|
||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if args.vae is not None:
|
if args.vae is not None:
|
||||||
@@ -2352,13 +2353,20 @@ def main(args):
|
|||||||
vae = sli_vae
|
vae = sli_vae
|
||||||
del sli_vae
|
del sli_vae
|
||||||
vae.to(dtype).to(device)
|
vae.to(dtype).to(device)
|
||||||
|
vae.eval()
|
||||||
|
|
||||||
text_encoder.to(dtype).to(device)
|
text_encoder.to(dtype).to(device)
|
||||||
unet.to(dtype).to(device)
|
unet.to(dtype).to(device)
|
||||||
|
|
||||||
|
text_encoder.eval()
|
||||||
|
unet.eval()
|
||||||
|
|
||||||
if clip_model is not None:
|
if clip_model is not None:
|
||||||
clip_model.to(dtype).to(device)
|
clip_model.to(dtype).to(device)
|
||||||
|
clip_model.eval()
|
||||||
if vgg16_model is not None:
|
if vgg16_model is not None:
|
||||||
vgg16_model.to(dtype).to(device)
|
vgg16_model.to(dtype).to(device)
|
||||||
|
vgg16_model.eval()
|
||||||
|
|
||||||
# networkを組み込む
|
# networkを組み込む
|
||||||
if args.network_module:
|
if args.network_module:
|
||||||
|
|||||||
@@ -1148,10 +1148,6 @@ class UpBlock2D(nn.Module):
|
|||||||
res_hidden_states = res_hidden_states_tuple[-1]
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
# Deep Shrink
|
|
||||||
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
|
||||||
hidden_states = resize_like(hidden_states, res_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
@@ -1244,10 +1240,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
|||||||
res_hidden_states = res_hidden_states_tuple[-1]
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
# Deep Shrink
|
|
||||||
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
|
||||||
hidden_states = resize_like(hidden_states, res_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
@@ -1444,31 +1436,6 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
self.conv_act = nn.SiLU()
|
self.conv_act = nn.SiLU()
|
||||||
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
||||||
|
|
||||||
# Deep Shrink
|
|
||||||
self.ds_depth_1 = None
|
|
||||||
self.ds_depth_2 = None
|
|
||||||
self.ds_timesteps_1 = None
|
|
||||||
self.ds_timesteps_2 = None
|
|
||||||
self.ds_ratio = None
|
|
||||||
|
|
||||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
|
||||||
if ds_depth_1 is None:
|
|
||||||
print("Deep Shrink is disabled.")
|
|
||||||
self.ds_depth_1 = None
|
|
||||||
self.ds_timesteps_1 = None
|
|
||||||
self.ds_depth_2 = None
|
|
||||||
self.ds_timesteps_2 = None
|
|
||||||
self.ds_ratio = None
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
|
||||||
)
|
|
||||||
self.ds_depth_1 = ds_depth_1
|
|
||||||
self.ds_timesteps_1 = ds_timesteps_1
|
|
||||||
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
|
||||||
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
|
||||||
self.ds_ratio = ds_ratio
|
|
||||||
|
|
||||||
# region diffusers compatibility
|
# region diffusers compatibility
|
||||||
def prepare_config(self):
|
def prepare_config(self):
|
||||||
self.config = SimpleNamespace()
|
self.config = SimpleNamespace()
|
||||||
@@ -1572,20 +1539,7 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
sample = self.conv_in(sample)
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
down_block_res_samples = (sample,)
|
down_block_res_samples = (sample,)
|
||||||
for depth, downsample_block in enumerate(self.down_blocks):
|
for downsample_block in self.down_blocks:
|
||||||
# Deep Shrink
|
|
||||||
if self.ds_depth_1 is not None:
|
|
||||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
|
||||||
self.ds_depth_2 is not None
|
|
||||||
and depth == self.ds_depth_2
|
|
||||||
and timesteps[0] < self.ds_timesteps_1
|
|
||||||
and timesteps[0] >= self.ds_timesteps_2
|
|
||||||
):
|
|
||||||
org_dtype = sample.dtype
|
|
||||||
if org_dtype == torch.bfloat16:
|
|
||||||
sample = sample.to(torch.float32)
|
|
||||||
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
|
||||||
|
|
||||||
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||||
# まあこちらのほうがわかりやすいかもしれない
|
# まあこちらのほうがわかりやすいかもしれない
|
||||||
if downsample_block.has_cross_attention:
|
if downsample_block.has_cross_attention:
|
||||||
@@ -1668,3 +1622,255 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
timesteps = timesteps.expand(sample.shape[0])
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
return timesteps
|
return timesteps
|
||||||
|
|
||||||
|
|
||||||
|
class InferUNet2DConditionModel:
|
||||||
|
def __init__(self, original_unet: UNet2DConditionModel):
|
||||||
|
self.delegate = original_unet
|
||||||
|
|
||||||
|
# override original model's forward method: because forward is not called by `__call__`
|
||||||
|
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
||||||
|
self.delegate.forward = self.forward
|
||||||
|
|
||||||
|
# override original model's up blocks' forward method
|
||||||
|
for up_block in self.delegate.up_blocks:
|
||||||
|
if up_block.__class__.__name__ == "UpBlock2D":
|
||||||
|
|
||||||
|
def resnet_wrapper(func, block):
|
||||||
|
def forward(*args, **kwargs):
|
||||||
|
return func(block, *args, **kwargs)
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
|
||||||
|
|
||||||
|
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||||
|
|
||||||
|
def cross_attn_up_wrapper(func, block):
|
||||||
|
def forward(*args, **kwargs):
|
||||||
|
return func(block, *args, **kwargs)
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
|
||||||
|
# call original model's methods
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.delegate, name)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.delegate(*args, **kwargs)
|
||||||
|
|
||||||
|
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||||
|
if ds_depth_1 is None:
|
||||||
|
print("Deep Shrink is disabled.")
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||||
|
)
|
||||||
|
self.ds_depth_1 = ds_depth_1
|
||||||
|
self.ds_timesteps_1 = ds_timesteps_1
|
||||||
|
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
||||||
|
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||||
|
self.ds_ratio = ds_ratio
|
||||||
|
|
||||||
|
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||||
|
for resnet in _self.resnets:
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
||||||
|
hidden_states = resize_like(hidden_states, res_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
|
||||||
|
if _self.upsamplers is not None:
|
||||||
|
for upsampler in _self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def cross_attn_up_block_forward(
|
||||||
|
self,
|
||||||
|
_self,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states_tuple,
|
||||||
|
temb=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
upsample_size=None,
|
||||||
|
):
|
||||||
|
for resnet, attn in zip(_self.resnets, _self.attentions):
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
||||||
|
hidden_states = resize_like(hidden_states, res_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||||
|
|
||||||
|
if _self.upsamplers is not None:
|
||||||
|
for upsampler in _self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[Dict, Tuple]:
|
||||||
|
r"""
|
||||||
|
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
|
||||||
|
"""
|
||||||
|
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||||
|
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||||
|
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a dict instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`SampleOutput` or `tuple`:
|
||||||
|
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_self = self.delegate
|
||||||
|
|
||||||
|
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||||
|
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||||
|
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||||
|
# on the fly if necessary.
|
||||||
|
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
||||||
|
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
||||||
|
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
||||||
|
default_overall_up_factor = 2**_self.num_upsamplers
|
||||||
|
|
||||||
|
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||||
|
# 64で割り切れないときはupsamplerにサイズを伝える
|
||||||
|
forward_upsample_size = False
|
||||||
|
upsample_size = None
|
||||||
|
|
||||||
|
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||||
|
# logger.info("Forward upsample size to force interpolation output size.")
|
||||||
|
forward_upsample_size = True
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
||||||
|
|
||||||
|
t_emb = _self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
||||||
|
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
||||||
|
# time_projでキャストしておけばいいんじゃね?
|
||||||
|
t_emb = t_emb.to(dtype=_self.dtype)
|
||||||
|
emb = _self.time_embedding(t_emb)
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = _self.conv_in(sample)
|
||||||
|
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for depth, downsample_block in enumerate(_self.down_blocks):
|
||||||
|
# Deep Shrink
|
||||||
|
if self.ds_depth_1 is not None:
|
||||||
|
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||||
|
self.ds_depth_2 is not None
|
||||||
|
and depth == self.ds_depth_2
|
||||||
|
and timesteps[0] < self.ds_timesteps_1
|
||||||
|
and timesteps[0] >= self.ds_timesteps_2
|
||||||
|
):
|
||||||
|
org_dtype = sample.dtype
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
sample = sample.to(torch.float32)
|
||||||
|
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
||||||
|
|
||||||
|
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||||
|
# まあこちらのほうがわかりやすいかもしれない
|
||||||
|
if downsample_block.has_cross_attention:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# skip connectionにControlNetの出力を追加する
|
||||||
|
if down_block_additional_residuals is not None:
|
||||||
|
down_block_res_samples = list(down_block_res_samples)
|
||||||
|
for i in range(len(down_block_res_samples)):
|
||||||
|
down_block_res_samples[i] += down_block_additional_residuals[i]
|
||||||
|
down_block_res_samples = tuple(down_block_res_samples)
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
|
||||||
|
# ControlNetの出力を追加する
|
||||||
|
if mid_block_additional_residual is not None:
|
||||||
|
sample += mid_block_additional_residual
|
||||||
|
|
||||||
|
# 5. up
|
||||||
|
for i, upsample_block in enumerate(_self.up_blocks):
|
||||||
|
is_final_block = i == len(_self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||||
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the upsample size, we do it here
|
||||||
|
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
||||||
|
if not is_final_block and forward_upsample_size:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if upsample_block.has_cross_attention:
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. post-process
|
||||||
|
sample = _self.conv_norm_out(sample)
|
||||||
|
sample = _self.conv_act(sample)
|
||||||
|
sample = _self.conv_out(sample)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
|
||||||
|
return SampleOutput(sample=sample)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -1013,31 +1013,6 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Deep Shrink
|
|
||||||
self.ds_depth_1 = None
|
|
||||||
self.ds_depth_2 = None
|
|
||||||
self.ds_timesteps_1 = None
|
|
||||||
self.ds_timesteps_2 = None
|
|
||||||
self.ds_ratio = None
|
|
||||||
|
|
||||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
|
||||||
if ds_depth_1 is None:
|
|
||||||
print("Deep Shrink is disabled.")
|
|
||||||
self.ds_depth_1 = None
|
|
||||||
self.ds_timesteps_1 = None
|
|
||||||
self.ds_depth_2 = None
|
|
||||||
self.ds_timesteps_2 = None
|
|
||||||
self.ds_ratio = None
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
|
||||||
)
|
|
||||||
self.ds_depth_1 = ds_depth_1
|
|
||||||
self.ds_timesteps_1 = ds_timesteps_1
|
|
||||||
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
|
||||||
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
|
||||||
self.ds_ratio = ds_ratio
|
|
||||||
|
|
||||||
# region diffusers compatibility
|
# region diffusers compatibility
|
||||||
def prepare_config(self):
|
def prepare_config(self):
|
||||||
self.config = SimpleNamespace()
|
self.config = SimpleNamespace()
|
||||||
@@ -1120,7 +1095,97 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
# h = x.type(self.dtype)
|
# h = x.type(self.dtype)
|
||||||
h = x
|
h = x
|
||||||
|
|
||||||
for depth, module in enumerate(self.input_blocks):
|
for module in self.input_blocks:
|
||||||
|
h = call_module(module, h, emb, context)
|
||||||
|
hs.append(h)
|
||||||
|
|
||||||
|
h = call_module(self.middle_block, h, emb, context)
|
||||||
|
|
||||||
|
for module in self.output_blocks:
|
||||||
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
|
h = call_module(module, h, emb, context)
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
h = call_module(self.out, h, emb, context)
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class InferSdxlUNet2DConditionModel:
|
||||||
|
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
|
||||||
|
self.delegate = original_unet
|
||||||
|
|
||||||
|
# override original model's forward method: because forward is not called by `__call__`
|
||||||
|
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
||||||
|
self.delegate.forward = self.forward
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
|
||||||
|
# call original model's methods
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.delegate, name)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.delegate(*args, **kwargs)
|
||||||
|
|
||||||
|
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||||
|
if ds_depth_1 is None:
|
||||||
|
print("Deep Shrink is disabled.")
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||||
|
)
|
||||||
|
self.ds_depth_1 = ds_depth_1
|
||||||
|
self.ds_timesteps_1 = ds_timesteps_1
|
||||||
|
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
||||||
|
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||||
|
self.ds_ratio = ds_ratio
|
||||||
|
|
||||||
|
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||||
|
r"""
|
||||||
|
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
|
||||||
|
"""
|
||||||
|
_self = self.delegate
|
||||||
|
|
||||||
|
# broadcast timesteps to batch dimension
|
||||||
|
timesteps = timesteps.expand(x.shape[0])
|
||||||
|
|
||||||
|
hs = []
|
||||||
|
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
|
||||||
|
t_emb = t_emb.to(x.dtype)
|
||||||
|
emb = _self.time_embed(t_emb)
|
||||||
|
|
||||||
|
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
||||||
|
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
||||||
|
# assert x.dtype == _self.dtype
|
||||||
|
emb = emb + _self.label_emb(y)
|
||||||
|
|
||||||
|
def call_module(module, h, emb, context):
|
||||||
|
x = h
|
||||||
|
for layer in module:
|
||||||
|
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
||||||
|
if isinstance(layer, ResnetBlock2D):
|
||||||
|
x = layer(x, emb)
|
||||||
|
elif isinstance(layer, Transformer2DModel):
|
||||||
|
x = layer(x, context)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# h = x.type(self.dtype)
|
||||||
|
h = x
|
||||||
|
|
||||||
|
for depth, module in enumerate(_self.input_blocks):
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
if self.ds_depth_1 is not None:
|
if self.ds_depth_1 is not None:
|
||||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||||
@@ -1138,9 +1203,9 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
h = call_module(module, h, emb, context)
|
h = call_module(module, h, emb, context)
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
|
|
||||||
h = call_module(self.middle_block, h, emb, context)
|
h = call_module(_self.middle_block, h, emb, context)
|
||||||
|
|
||||||
for module in self.output_blocks:
|
for module in _self.output_blocks:
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
if self.ds_depth_1 is not None:
|
if self.ds_depth_1 is not None:
|
||||||
if hs[-1].shape[-2:] != h.shape[-2:]:
|
if hs[-1].shape[-2:] != h.shape[-2:]:
|
||||||
@@ -1156,7 +1221,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
h = resize_like(h, x)
|
h = resize_like(h, x)
|
||||||
|
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
h = call_module(self.out, h, emb, context)
|
h = call_module(_self.out, h, emb, context)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ import library.train_util as train_util
|
|||||||
import library.sdxl_model_util as sdxl_model_util
|
import library.sdxl_model_util as sdxl_model_util
|
||||||
import library.sdxl_train_util as sdxl_train_util
|
import library.sdxl_train_util as sdxl_train_util
|
||||||
from networks.lora import LoRANetwork
|
from networks.lora import LoRANetwork
|
||||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||||
from library.original_unet import FlashAttentionFunction
|
from library.original_unet import FlashAttentionFunction
|
||||||
from networks.control_net_lllite import ControlNetLLLite
|
from networks.control_net_lllite import ControlNetLLLite
|
||||||
|
|
||||||
@@ -290,7 +290,7 @@ class PipelineLike:
|
|||||||
vae: AutoencoderKL,
|
vae: AutoencoderKL,
|
||||||
text_encoders: List[CLIPTextModel],
|
text_encoders: List[CLIPTextModel],
|
||||||
tokenizers: List[CLIPTokenizer],
|
tokenizers: List[CLIPTokenizer],
|
||||||
unet: SdxlUNet2DConditionModel,
|
unet: InferSdxlUNet2DConditionModel,
|
||||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||||
clip_skip: int,
|
clip_skip: int,
|
||||||
):
|
):
|
||||||
@@ -328,7 +328,7 @@ class PipelineLike:
|
|||||||
self.vae = vae
|
self.vae = vae
|
||||||
self.text_encoders = text_encoders
|
self.text_encoders = text_encoders
|
||||||
self.tokenizers = tokenizers
|
self.tokenizers = tokenizers
|
||||||
self.unet: SdxlUNet2DConditionModel = unet
|
self.unet: InferSdxlUNet2DConditionModel = unet
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
|
|
||||||
@@ -1371,6 +1371,7 @@ def main(args):
|
|||||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||||
)
|
)
|
||||||
|
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
|
||||||
|
|
||||||
# xformers、Hypernetwork対応
|
# xformers、Hypernetwork対応
|
||||||
if not args.diffusers_xformers:
|
if not args.diffusers_xformers:
|
||||||
@@ -1526,10 +1527,14 @@ def main(args):
|
|||||||
print("set vae_dtype to float32")
|
print("set vae_dtype to float32")
|
||||||
vae_dtype = torch.float32
|
vae_dtype = torch.float32
|
||||||
vae.to(vae_dtype).to(device)
|
vae.to(vae_dtype).to(device)
|
||||||
|
vae.eval()
|
||||||
|
|
||||||
text_encoder1.to(dtype).to(device)
|
text_encoder1.to(dtype).to(device)
|
||||||
text_encoder2.to(dtype).to(device)
|
text_encoder2.to(dtype).to(device)
|
||||||
unet.to(dtype).to(device)
|
unet.to(dtype).to(device)
|
||||||
|
text_encoder1.eval()
|
||||||
|
text_encoder2.eval()
|
||||||
|
unet.eval()
|
||||||
|
|
||||||
# networkを組み込む
|
# networkを組み込む
|
||||||
if args.network_module:
|
if args.network_module:
|
||||||
|
|||||||
Reference in New Issue
Block a user