From 035dd3a900ce9208c5d703c354ec839166b8f9cc Mon Sep 17 00:00:00 2001
From: ykume
Date: Sun, 11 Jun 2023 17:08:21 +0900
Subject: [PATCH] fix mem_eff_attn does not work

---
 library/original_unet.py |  10 ++--
 library/train_util.py    | 101 ++++-----------------------------------
 2 files changed, 15 insertions(+), 96 deletions(-)

diff --git a/library/original_unet.py b/library/original_unet.py
index 0e64280b..36318eb9 100644
--- a/library/original_unet.py
+++ b/library/original_unet.py
@@ -278,7 +278,7 @@ class FlashAttentionFunction(torch.autograd.Function):
         for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
             k_start_index = k_ind * k_bucket_size
 
-            attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+            attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
 
             if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                 causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
@@ -293,14 +293,14 @@ class FlashAttentionFunction(torch.autograd.Function):
 
             p = exp_attn_weights / lc
 
-            dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
-            dp = einsum("... i d, ... j d -> ... i j", doc, vc)
+            dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
+            dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
 
             D = (doc * oc).sum(dim=-1, keepdims=True)
             ds = p * scale * (dp - D)
 
-            dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
-            dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
+            dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
+            dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
 
             dqc.add_(dq_chunk)
             dkc.add_(dk_chunk)
diff --git a/library/train_util.py b/library/train_util.py
index 8aa7f987..7d7eb325 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -63,6 +63,7 @@ import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
 import library.huggingface_util as huggingface_util
+from library.original_unet import UNet2DConditionModel
 
 # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1787,100 +1788,18 @@ class FlashAttentionFunction(torch.autograd.function.Function):
         return dq, dk, dv, None, None, None, None
 
 
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
-    # unet is not used currently, but it is here for future use
+def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers):
     if mem_eff_attn:
-        replace_unet_cross_attn_to_memory_efficient()
+        print("Enable memory efficient attention for U-Net")
+        unet.set_use_memory_efficient_attention(False, True)
     elif xformers:
-        replace_unet_cross_attn_to_xformers(unet)
-
-
-def replace_unet_cross_attn_to_memory_efficient():
-    print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
-    flash_func = FlashAttentionFunction
-
-    def forward_flash_attn(self, x, context=None, mask=None):
-        q_bucket_size = 512
-        k_bucket_size = 1024
-
-        h = self.heads
-        q = self.to_q(x)
-
-        context = context if context is not None else x
-        context = context.to(x.dtype)
-
-        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-            context_k, context_v = self.hypernetwork.forward(x, context)
-            context_k = context_k.to(x.dtype)
-            context_v = context_v.to(x.dtype)
-        else:
-            context_k = context
-            context_v = context
-
-        k = self.to_k(context_k)
-        v = self.to_v(context_v)
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
-
-        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
-        out = rearrange(out, "b h n d -> b n (h d)")
-
-        out = self.to_out[0](out)
-        # out = self.to_out[1](out)
-        return out
-
-    # diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-    from library.original_unet import CrossAttention
-
-    CrossAttention.forward = forward_flash_attn
-
-
-def replace_unet_cross_attn_to_xformers(unet):
-    print("CrossAttention.forward has been replaced to enable xformers.")
-    try:
-        import xformers.ops
-    except ImportError:
-        raise ImportError("No xformers / xformersがインストールされていないようです")
-
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    # def forward_xformers(self, x, context=None, mask=None):
-    #     h = self.heads
-    #     q_in = self.to_q(x)
-
-    #     context = default(context, x)
-    #     context = context.to(x.dtype)
-
-    #     if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-    #         context_k, context_v = self.hypernetwork.forward(x, context)
-    #         context_k = context_k.to(x.dtype)
-    #         context_v = context_v.to(x.dtype)
-    #     else:
-    #         context_k = context
-    #         context_v = context
-
-    #     k_in = self.to_k(context_k)
-    #     v_in = self.to_v(context_v)
-
-    #     q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
-    #     del q_in, k_in, v_in
-
-    #     q = q.contiguous()
-    #     k = k.contiguous()
-    #     v = v.contiguous()
-    #     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # 最適なのを選んでくれる
-
-    #     out = rearrange(out, "b n h d -> b n (h d)", h=h)
-
-    #     # diffusers 0.7.0~
-    #     out = self.to_out[0](out)
-    #     out = self.to_out[1](out)
-    #     return out
-
-    # diffusers.models.attention.CrossAttention.forward = forward_xformers
+        print("Enable xformers for U-Net")
+        try:
+            import xformers.ops
+        except ImportError:
+            raise ImportError("No xformers / xformersがインストールされていないようです")
+        unet.set_use_memory_efficient_attention(True, False)
 
 
 """
 def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
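
After this patch, replace_unet_modules() no longer monkey-patches CrossAttention.forward; it delegates to the set_use_memory_efficient_attention(xformers, mem_eff) method of the custom UNet2DConditionModel defined in library/original_unet.py. Below is a minimal sketch of how a training script might call the patched function; the flag names args.mem_eff_attn and args.xformers and the setup_attention wrapper are illustrative assumptions, not part of this patch.

    # Sketch only: assumes the U-Net was instantiated as
    # library.original_unet.UNet2DConditionModel, which provides
    # set_use_memory_efficient_attention(xformers, mem_eff).
    from library import train_util

    def setup_attention(unet, args):
        # mem_eff_attn -> unet.set_use_memory_efficient_attention(False, True)
        # xformers     -> unet.set_use_memory_efficient_attention(True, False)
        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)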