Mirror of https://github.com/kohya-ss/sd-scripts.git
use xformers in VAE in gen script
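In short, this commit adds a replace_vae_modules() helper to the generation script and calls it right after the existing replace_unet_modules() call, so the VAE also gets memory-efficient attention (FlashAttention or Diffusers' xformers) instead of only the U-Net. A minimal sketch of what the xformers branch boils down to, assuming diffusers and xformers are installed; the model id below is only an illustrative placeholder, not part of the commit:

import diffusers

# Load a Stable Diffusion VAE (repo id is just an example).
vae = diffusers.AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# This is what the new replace_vae_modules(vae, mem_eff_attn=False, xformers=True)
# path effectively does: switch the VAE's attention (only the mid blocks have any)
# to Diffusers' xformers memory-efficient implementation.
vae.set_use_memory_efficient_attention_xformers(True)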
@@ -311,6 +311,7 @@ class FlashAttentionFunction(torch.autograd.Function):
         return dq, dk, dv, None, None, None, None
 
 
+# TODO common train_util.py
 def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
     if mem_eff_attn:
         replace_unet_cross_attn_to_memory_efficient()
@@ -319,7 +320,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
 
 
 def replace_unet_cross_attn_to_memory_efficient():
-    print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention")
+    print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork")
     flash_func = FlashAttentionFunction
 
     def forward_flash_attn(self, x, context=None, mask=None):
@@ -359,7 +360,7 @@ def replace_unet_cross_attn_to_memory_efficient():
 
 
 def replace_unet_cross_attn_to_xformers():
-    print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers")
+    print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
     try:
         import xformers.ops
     except ImportError:
@@ -401,6 +402,104 @@ def replace_unet_cross_attn_to_xformers():
     diffusers.models.attention.CrossAttention.forward = forward_xformers
 
 
+def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
+    if mem_eff_attn:
+        replace_vae_attn_to_memory_efficient()
+    elif xformers:
+        # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
+        print("Use Diffusers xformers for VAE")
+        vae.set_use_memory_efficient_attention_xformers(True)
+
+    """
+    # VAEがbfloat16でメモリ消費が大きい問題を解決する
+    upsamplers = []
+    for block in vae.decoder.up_blocks:
+        if block.upsamplers is not None:
+            upsamplers.extend(block.upsamplers)
+
+    def forward_upsample(_self, hidden_states, output_size=None):
+        assert hidden_states.shape[1] == _self.channels
+        if _self.use_conv_transpose:
+            return _self.conv(hidden_states)
+
+        dtype = hidden_states.dtype
+        if dtype == torch.bfloat16:
+            assert output_size is None
+            # repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する
+            hidden_states = hidden_states.repeat_interleave(2, dim=-1)
+            hidden_states = hidden_states.repeat_interleave(2, dim=-2)
+        else:
+            if hidden_states.shape[0] >= 64:
+                hidden_states = hidden_states.contiguous()
+
+            # if `output_size` is passed we force the interpolation output
+            # size and do not make use of `scale_factor=2`
+            if output_size is None:
+                hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+            else:
+                hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest")
+
+        if _self.use_conv:
+            if _self.name == "conv":
+                hidden_states = _self.conv(hidden_states)
+            else:
+                hidden_states = _self.Conv2d_0(hidden_states)
+        return hidden_states
+
+    # replace upsamplers
+    for upsampler in upsamplers:
+        # make new scope
+        def make_replacer(upsampler):
+            def forward(hidden_states, output_size=None):
+                return forward_upsample(upsampler, hidden_states, output_size)
+
+            return forward
+
+        upsampler.forward = make_replacer(upsampler)
+    """
+
+
+def replace_vae_attn_to_memory_efficient():
+    print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
+    flash_func = FlashAttentionFunction
+
+    def forward_flash_attn(self, hidden_states):
+        print("forward_flash_attn")
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_proj, key_proj, value_proj = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
+        )
+
+        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(hidden_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
+
+
 # endregion
 
 
 # region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正
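The large triple-quoted block in the hunk above is kept commented out. Per its Japanese comments, it was an attempt to work around the VAE's high memory use in bfloat16 by replacing the decoder upsamplers' nearest-neighbor F.interpolate with repeat_interleave (noted there as very slow, but called rarely enough to be acceptable). The sketch below is not part of the commit; it only demonstrates that the two formulations produce the same 2x nearest-neighbor upsample:

import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Nearest-neighbor 2x upsampling via interpolate ...
a = F.interpolate(x, scale_factor=2.0, mode="nearest")
# ... equals duplicating every element along width and then height.
b = x.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2)

assert torch.equal(a, b)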
@@ -2142,6 +2241,7 @@ def main(args):
     # xformers、Hypernetwork対応
     if not args.diffusers_xformers:
         replace_unet_modules(unet, not args.xformers, args.xformers)
+        replace_vae_modules(vae, not args.xformers, args.xformers)
 
     # tokenizerを読み込む
     print("loading tokenizer")
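Note how the flags map in the hunk above: when --diffusers_xformers is not given, the script passes mem_eff_attn = not args.xformers, so a plain invocation gets the FlashAttention path and --xformers gets the xformers path, now for both the U-Net and the VAE. A tiny illustration of that mapping, assuming both options are simple store_true flags (which matches how they are used here):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--xformers", action="store_true")
parser.add_argument("--diffusers_xformers", action="store_true")

for argv in ([], ["--xformers"], ["--diffusers_xformers"]):
    args = parser.parse_args(argv)
    if not args.diffusers_xformers:
        # mem_eff_attn is simply the negation of --xformers, as in the diff
        print(argv, "-> mem_eff_attn =", not args.xformers, ", xformers =", args.xformers)
    else:
        print(argv, "-> let Diffusers handle xformers directly")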
@@ -3175,8 +3275,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--vae_slices",
         type=int,
         default=None,
-        help=
-        "number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨"
+        help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨",
     )
     parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
     parser.add_argument(
@@ -1765,6 +1765,7 @@ class FlashAttentionFunction(torch.autograd.function.Function):
 
 
 def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
+    # unet is not used currently, but it is here for future use
     if mem_eff_attn:
         replace_unet_cross_attn_to_memory_efficient()
     elif xformers:
@@ -1772,7 +1773,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
 
 
 def replace_unet_cross_attn_to_memory_efficient():
-    print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
+    print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
     flash_func = FlashAttentionFunction
 
     def forward_flash_attn(self, x, context=None, mask=None):
@@ -1812,7 +1813,7 @@ def replace_unet_cross_attn_to_memory_efficient():
 
 
 def replace_unet_cross_attn_to_xformers():
-    print("Replace CrossAttention.forward to use xformers")
+    print("CrossAttention.forward has been replaced to enable xformers.")
    try:
        import xformers.ops
    except ImportError:
@@ -1854,6 +1855,60 @@ def replace_unet_cross_attn_to_xformers():
     diffusers.models.attention.CrossAttention.forward = forward_xformers
 
 
+"""
+def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
+    # vae is not used currently, but it is here for future use
+    if mem_eff_attn:
+        replace_vae_attn_to_memory_efficient()
+    elif xformers:
+        # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
+        print("Use Diffusers xformers for VAE")
+        vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
+        vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
+
+
+def replace_vae_attn_to_memory_efficient():
+    print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
+    flash_func = FlashAttentionFunction
+
+    def forward_flash_attn(self, hidden_states):
+        print("forward_flash_attn")
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_proj, key_proj, value_proj = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
+        )
+
+        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(hidden_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
+"""
+
+
 # endregion
 
 
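The commented-out variant added in this second file enables xformers only on the two attention blocks the VAE actually has: as the Japanese comment notes, attention exists only in the mid blocks of the encoder and decoder. The generation script instead calls set_use_memory_efficient_attention_xformers(True) on the whole model, which reaches the same modules. A small sketch that lists the VAE's attention modules to show why the two approaches cover the same ground; the module class names and the repo id are assumptions for illustration, not taken from the commit:

import diffusers

vae = diffusers.AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# Collect every attention module in the VAE; older diffusers versions call the class
# AttentionBlock, newer ones call it Attention, so match either name.
attn = [name for name, m in vae.named_modules() if type(m).__name__ in ("AttentionBlock", "Attention")]
print(attn)  # expected: ['encoder.mid_block.attentions.0', 'decoder.mid_block.attentions.0']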