mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
use xformers in VAE in gen script
This commit is contained in:
@@ -311,6 +311,7 @@ class FlashAttentionFunction(torch.autograd.Function):
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
# TODO common train_util.py
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
if mem_eff_attn:
|
||||
replace_unet_cross_attn_to_memory_efficient()
|
||||
@@ -319,7 +320,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
|
||||
print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention")
|
||||
print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, x, context=None, mask=None):
|
||||
@@ -359,7 +360,7 @@ def replace_unet_cross_attn_to_memory_efficient():
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers")
|
||||
print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
@@ -401,6 +402,104 @@ def replace_unet_cross_attn_to_xformers():
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
|
||||
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
if mem_eff_attn:
|
||||
replace_vae_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
|
||||
print("Use Diffusers xformers for VAE")
|
||||
vae.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
"""
|
||||
# VAEがbfloat16でメモリ消費が大きい問題を解決する
|
||||
upsamplers = []
|
||||
for block in vae.decoder.up_blocks:
|
||||
if block.upsamplers is not None:
|
||||
upsamplers.extend(block.upsamplers)
|
||||
|
||||
def forward_upsample(_self, hidden_states, output_size=None):
|
||||
assert hidden_states.shape[1] == _self.channels
|
||||
if _self.use_conv_transpose:
|
||||
return _self.conv(hidden_states)
|
||||
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
assert output_size is None
|
||||
# repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する
|
||||
hidden_states = hidden_states.repeat_interleave(2, dim=-1)
|
||||
hidden_states = hidden_states.repeat_interleave(2, dim=-2)
|
||||
else:
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if output_size is None:
|
||||
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
if _self.use_conv:
|
||||
if _self.name == "conv":
|
||||
hidden_states = _self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = _self.Conv2d_0(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
# replace upsamplers
|
||||
for upsampler in upsamplers:
|
||||
# make new scope
|
||||
def make_replacer(upsampler):
|
||||
def forward(hidden_states, output_size=None):
|
||||
return forward_upsample(upsampler, hidden_states, output_size)
|
||||
|
||||
return forward
|
||||
|
||||
upsampler.forward = make_replacer(upsampler)
|
||||
"""
|
||||
|
||||
|
||||
def replace_vae_attn_to_memory_efficient():
|
||||
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, hidden_states):
|
||||
print("forward_flash_attn")
|
||||
q_bucket_size = 512
|
||||
k_bucket_size = 1024
|
||||
|
||||
residual = hidden_states
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
|
||||
# norm
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
||||
|
||||
# proj to q, k, v
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_proj, key_proj, value_proj = map(
|
||||
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
|
||||
)
|
||||
|
||||
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
||||
|
||||
# res connect and rescale
|
||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
||||
return hidden_states
|
||||
|
||||
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正
|
||||
@@ -2142,6 +2241,7 @@ def main(args):
|
||||
# xformers、Hypernetwork対応
|
||||
if not args.diffusers_xformers:
|
||||
replace_unet_modules(unet, not args.xformers, args.xformers)
|
||||
replace_vae_modules(vae, not args.xformers, args.xformers)
|
||||
|
||||
# tokenizerを読み込む
|
||||
print("loading tokenizer")
|
||||
@@ -3175,8 +3275,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--vae_slices",
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
"number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨"
|
||||
help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨",
|
||||
)
|
||||
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user