fix mem_eff_attn does not work

This commit is contained in:
ykume
2023-06-11 17:08:21 +09:00
parent 4e25c8f78e
commit 035dd3a900
2 changed files with 15 additions and 96 deletions

View File

@@ -278,7 +278,7 @@ class FlashAttentionFunction(torch.autograd.Function):
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size k_start_index = k_ind * k_bucket_size
attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
if causal and q_start_index < (k_start_index + k_bucket_size - 1): if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
@@ -293,14 +293,14 @@ class FlashAttentionFunction(torch.autograd.Function):
p = exp_attn_weights / lc p = exp_attn_weights / lc
dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
dp = einsum("... i d, ... j d -> ... i j", doc, vc) dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
D = (doc * oc).sum(dim=-1, keepdims=True) D = (doc * oc).sum(dim=-1, keepdims=True)
ds = p * scale * (dp - D) ds = p * scale * (dp - D)
dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
dqc.add_(dq_chunk) dqc.add_(dq_chunk)
dkc.add_(dk_chunk) dkc.add_(dk_chunk)

View File

@@ -63,6 +63,7 @@ import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util import library.model_util as model_util
import library.huggingface_util as huggingface_util import library.huggingface_util as huggingface_util
from library.original_unet import UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14" TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1787,100 +1788,18 @@ class FlashAttentionFunction(torch.autograd.function.Function):
return dq, dk, dv, None, None, None, None return dq, dk, dv, None, None, None, None
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers):
# unet is not used currently, but it is here for future use
if mem_eff_attn: if mem_eff_attn:
replace_unet_cross_attn_to_memory_efficient() print("Enable memory efficient attention for U-Net")
unet.set_use_memory_efficient_attention(False, True)
elif xformers: elif xformers:
replace_unet_cross_attn_to_xformers(unet) print("Enable xformers for U-Net")
def replace_unet_cross_attn_to_memory_efficient():
print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, x, context=None, mask=None):
q_bucket_size = 512
k_bucket_size = 1024
h = self.heads
q = self.to_q(x)
context = context if context is not None else x
context = context.to(x.dtype)
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
context_k, context_v = self.hypernetwork.forward(x, context)
context_k = context_k.to(x.dtype)
context_v = context_v.to(x.dtype)
else:
context_k = context
context_v = context
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, x
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.to_out[0](out)
# out = self.to_out[1](out)
return out
# diffusers.models.attention.CrossAttention.forward = forward_flash_attn
from library.original_unet import CrossAttention
CrossAttention.forward = forward_flash_attn
def replace_unet_cross_attn_to_xformers(unet):
print("CrossAttention.forward has been replaced to enable xformers.")
try: try:
import xformers.ops import xformers.ops
except ImportError: except ImportError:
raise ImportError("No xformers / xformersがインストールされていないようです") raise ImportError("No xformers / xformersがインストールされていないようです")
unet.set_use_memory_efficient_attention_xformers(True) unet.set_use_memory_efficient_attention(True, False)
# def forward_xformers(self, x, context=None, mask=None):
# h = self.heads
# q_in = self.to_q(x)
# context = default(context, x)
# context = context.to(x.dtype)
# if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
# context_k, context_v = self.hypernetwork.forward(x, context)
# context_k = context_k.to(x.dtype)
# context_v = context_v.to(x.dtype)
# else:
# context_k = context
# context_v = context
# k_in = self.to_k(context_k)
# v_in = self.to_v(context_v)
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
# del q_in, k_in, v_in
# q = q.contiguous()
# k = k.contiguous()
# v = v.contiguous()
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
# out = rearrange(out, "b n h d -> b n (h d)", h=h)
# # diffusers 0.7.0~
# out = self.to_out[0](out)
# out = self.to_out[1](out)
# return out
# diffusers.models.attention.CrossAttention.forward = forward_xformers
""" """
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):