Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00.
fix mem_eff_attn does not work
This commit is contained in:
@@ -63,6 +63,7 @@ import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
@@ -1787,100 +1788,18 @@ class FlashAttentionFunction(torch.autograd.function.Function):
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers):
    """Switch the U-Net's cross-attention to a memory-efficient backend.

    Args:
        unet: the U-Net whose attention implementation is toggled
            (also kept here for future use).
        mem_eff_attn: when truthy, use the pure-PyTorch FlashAttention path.
        xformers: when truthy, use the xformers backend (only consulted
            when ``mem_eff_attn`` is falsy).

    NOTE(review): the two positional arguments of
    ``set_use_memory_efficient_attention`` appear to be
    ``(xformers, mem_eff_attn)`` flags, judging from the two call sites —
    confirm against library.original_unet.
    """
    # Merge residue fixed: the source carried two ``def`` headers for this
    # function; the UNet2DConditionModel-annotated one is kept.
    if mem_eff_attn:
        replace_unet_cross_attn_to_memory_efficient()
        print("Enable memory efficient attention for U-Net")
        unet.set_use_memory_efficient_attention(False, True)
    elif xformers:
        replace_unet_cross_attn_to_xformers(unet)
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
    """Monkey-patch ``CrossAttention.forward`` with a FlashAttention-based
    implementation (pure PyTorch autograd function, not xformers)."""
    print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
    flash_func = FlashAttentionFunction

    def forward_flash_attn(self, x, context=None, mask=None):
        # Tile sizes for the bucketed flash-attention loop.
        q_bucket_size, k_bucket_size = 512, 1024
        heads = self.heads

        q = self.to_q(x)

        # Self-attention: fall back to x as its own context.
        if context is None:
            context = x
        context = context.to(x.dtype)

        # An attached hypernetwork, if any, supplies the key/value contexts.
        if getattr(self, "hypernetwork", None) is not None:
            context_k, context_v = self.hypernetwork.forward(x, context)
            context_k = context_k.to(x.dtype)
            context_v = context_v.to(x.dtype)
        else:
            context_k = context_v = context

        k = self.to_k(context_k)
        v = self.to_v(context_v)
        del context, x

        # (batch, seq, heads*dim) -> (batch, heads, seq, dim)
        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=heads) for t in (q, k, v))

        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
        out = rearrange(out, "b h n d -> b n (h d)")

        # Only the projection is applied; to_out[1] (presumably dropout)
        # is intentionally skipped, as in the original implementation.
        out = self.to_out[0](out)
        return out

    from library.original_unet import CrossAttention

    CrossAttention.forward = forward_flash_attn
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers(unet):
    """Enable the xformers memory-efficient attention backend on the U-Net.

    Args:
        unet: model exposing both ``set_use_memory_efficient_attention_xformers``
            (diffusers-style switch) and ``set_use_memory_efficient_attention``
            (library.original_unet-style switch).

    Raises:
        ImportError: when the xformers package is not installed.
    """
    print("CrossAttention.forward has been replaced to enable xformers.")
    # Availability check only — the merged code performed this identical
    # try/import twice; once is sufficient. (The dead, commented-out
    # forward_xformers monkey-patch has also been removed.)
    try:
        import xformers.ops  # noqa: F401
    except ImportError:
        raise ImportError("No xformers / xformersがインストールされていないようです")

    unet.set_use_memory_efficient_attention_xformers(True)

    print("Enable xformers for U-Net")
    # NOTE(review): arguments are presumably (xformers, mem_eff_attn) — confirm.
    unet.set_use_memory_efficient_attention(True, False)
|
||||
|
||||
"""
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
|
||||
Reference in New Issue
Block a user