support sdpa

This commit is contained in:
ykume
2023-06-11 21:26:15 +09:00
parent 4d0c06e397
commit 9e1683cf2b
9 changed files with 177 additions and 84 deletions

View File

@@ -137,7 +137,7 @@ USE_CUTOUTS = False
"""
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
print("Enable memory efficient attention for U-Net")
@@ -151,56 +151,26 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
raise ImportError("No xformers / xformersがインストールされていないようです")
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
print("Enable SDPA for U-Net")
unet.set_use_memory_efficient_attention(False, False)
unet.set_use_sdpa(True)
# TODO common train_util.py
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
replace_vae_attn_to_memory_efficient()
elif xformers:
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
replace_vae_attn_to_xformers()
elif sdpa:
replace_vae_attn_to_sdpa()
def replace_vae_attn_to_memory_efficient():
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn_0_14(self, hidden_states, **kwargs):
q_bucket_size = 512
k_bucket_size = 1024
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
out = rearrange(out, "b h n d -> b n (h d)")
# compute next hidden_states
# linear proj
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_flash_attn(self, hidden_states, **kwargs):
q_bucket_size = 512
k_bucket_size = 1024
@@ -238,6 +208,15 @@ def replace_vae_attn_to_memory_efficient():
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_flash_attn_0_14(self, hidden_states, **kwargs):
if not hasattr(self, "to_q"):
self.to_q = self.query
self.to_k = self.key
self.to_v = self.value
self.to_out = [self.proj_attn, torch.nn.Identity()]
self.heads = self.num_heads
return forward_flash_attn(self, hidden_states, **kwargs)
if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
else:
@@ -248,40 +227,6 @@ def replace_vae_attn_to_xformers():
print("VAE: Attention.forward has been replaced to xformers")
import xformers.ops
def forward_xformers_0_14(self, hidden_states, **kwargs):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)
query_proj = query_proj.contiguous()
key_proj = key_proj.contiguous()
value_proj = value_proj.contiguous()
out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
out = rearrange(out, "b h n d -> b n (h d)")
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_xformers(self, hidden_states, **kwargs):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
@@ -319,12 +264,75 @@ def replace_vae_attn_to_xformers():
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_xformers_0_14(self, hidden_states, **kwargs):
if not hasattr(self, "to_q"):
self.to_q = self.query
self.to_k = self.key
self.to_v = self.value
self.to_out = [self.proj_attn, torch.nn.Identity()]
self.heads = self.num_heads
return forward_xformers(self, hidden_states, **kwargs)
if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
else:
diffusers.models.attention_processor.Attention.forward = forward_xformers
def replace_vae_attn_to_sdpa():
print("VAE: Attention.forward has been replaced to sdpa")
def forward_sdpa(self, hidden_states, **kwargs):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.to_q(hidden_states)
key_proj = self.to_k(hidden_states)
value_proj = self.to_v(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj)
)
out = torch.nn.functional.scaled_dot_product_attention(
query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False
)
out = rearrange(out, "b n h d -> b n (h d)")
# compute next hidden_states
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_sdpa_0_14(self, hidden_states, **kwargs):
if not hasattr(self, "to_q"):
self.to_q = self.query
self.to_k = self.key
self.to_v = self.value
self.to_out = [self.proj_attn, torch.nn.Identity()]
self.heads = self.num_heads
return forward_sdpa(self, hidden_states, **kwargs)
if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14
else:
diffusers.models.attention_processor.Attention.forward = forward_sdpa
# endregion
# region 画像生成の本体lpw_stable_diffusion.py ASLからコピーして修正
@@ -2082,8 +2090,9 @@ def main(args):
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
replace_unet_modules(unet, not args.xformers, args.xformers)
replace_vae_modules(vae, not args.xformers, args.xformers)
mem_eff = not (args.xformers or args.sdpa)
replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa)
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
# tokenizerを読み込む
print("loading tokenizer")
@@ -3176,6 +3185,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する")
parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する")
parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")
parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa")
parser.add_argument(
"--diffusers_xformers",
action="store_true",