Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
fix mem_eff_attn does not work
@@ -278,7 +278,7 @@ class FlashAttentionFunction(torch.autograd.Function):
             for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size
 
-                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
 
                 if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                     causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
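The only change in this hunk (and the next one) is switching bare einsum calls to torch.einsum; presumably the bare name was not imported in this module, so the memory-efficient attention path failed at runtime, which is what the commit title refers to. A minimal sketch, not part of the commit, showing that the pattern used for attn_weights is just a scaled batched Q·K^T:

import torch

# Illustrative chunk shapes: (batch, heads, q_chunk, head_dim) and (batch, heads, k_chunk, head_dim)
qc = torch.randn(2, 8, 16, 64)
kc = torch.randn(2, 8, 32, 64)
scale = qc.shape[-1] ** -0.5

attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
assert attn_weights.shape == (2, 8, 16, 32)
# Same result as an explicit batched matmul against the transposed keys.
assert torch.allclose(attn_weights, qc @ kc.transpose(-2, -1) * scale, atol=1e-5)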
@@ -293,14 +293,14 @@ class FlashAttentionFunction(torch.autograd.Function):
 
                 p = exp_attn_weights / lc
 
-                dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
-                dp = einsum("... i d, ... j d -> ... i j", doc, vc)
+                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
+                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
 
                 D = (doc * oc).sum(dim=-1, keepdims=True)
                 ds = p * scale * (dp - D)
 
-                dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
-                dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
+                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
+                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
 
                 dqc.add_(dq_chunk)
                 dkc.add_(dk_chunk)
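For reference, the four einsum patterns touched in this hunk are the standard softmax-attention backward pass. A minimal unchunked sketch (illustrative shapes, not part of the commit) checking them against autograd:

import torch

q = torch.randn(4, 8, requires_grad=True)
k = torch.randn(6, 8, requires_grad=True)
v = torch.randn(6, 8, requires_grad=True)
scale = q.shape[-1] ** -0.5
do = torch.randn(4, 8)  # upstream gradient, the role played by `doc` above

p = torch.softmax(torch.einsum("i d, j d -> i j", q, k) * scale, dim=-1)
out = p @ v
out.backward(do)

dv = torch.einsum("i j, i d -> j d", p, do)
dp = torch.einsum("i d, j d -> i j", do, v)
D = (do * out.detach()).sum(dim=-1, keepdim=True)  # the hunk uses the `keepdims` alias
ds = p * scale * (dp - D)
dq = torch.einsum("i j, j d -> i d", ds, k)
dk = torch.einsum("i j, i d -> j d", ds, q)

assert torch.allclose(dv, v.grad, atol=1e-4)
assert torch.allclose(dq, q.grad, atol=1e-4)
assert torch.allclose(dk, k.grad, atol=1e-4)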
@@ -63,6 +63,7 @@ import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
 import library.huggingface_util as huggingface_util
+from library.original_unet import UNet2DConditionModel
 
 # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1787,100 +1788,18 @@ class FlashAttentionFunction(torch.autograd.function.Function):
         return dq, dk, dv, None, None, None, None
 
 
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
-    # unet is not used currently, but it is here for future use
+def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers):
     if mem_eff_attn:
-        replace_unet_cross_attn_to_memory_efficient()
+        print("Enable memory efficient attention for U-Net")
+        unet.set_use_memory_efficient_attention(False, True)
     elif xformers:
-        replace_unet_cross_attn_to_xformers(unet)
+        print("Enable xformers for U-Net")
+        try:
+            import xformers.ops
+        except ImportError:
+            raise ImportError("No xformers / xformersがインストールされていないようです")
+
+        unet.set_use_memory_efficient_attention(True, False)
-
-
-def replace_unet_cross_attn_to_memory_efficient():
-    print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
-    flash_func = FlashAttentionFunction
-
-    def forward_flash_attn(self, x, context=None, mask=None):
-        q_bucket_size = 512
-        k_bucket_size = 1024
-
-        h = self.heads
-        q = self.to_q(x)
-
-        context = context if context is not None else x
-        context = context.to(x.dtype)
-
-        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-            context_k, context_v = self.hypernetwork.forward(x, context)
-            context_k = context_k.to(x.dtype)
-            context_v = context_v.to(x.dtype)
-        else:
-            context_k = context
-            context_v = context
-
-        k = self.to_k(context_k)
-        v = self.to_v(context_v)
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
-
-        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
-        out = rearrange(out, "b h n d -> b n (h d)")
-
-        out = self.to_out[0](out)
-        # out = self.to_out[1](out)
-        return out
-
-    # diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-    from library.original_unet import CrossAttention
-
-    CrossAttention.forward = forward_flash_attn
-
-
-def replace_unet_cross_attn_to_xformers(unet):
-    print("CrossAttention.forward has been replaced to enable xformers.")
-    try:
-        import xformers.ops
-    except ImportError:
-        raise ImportError("No xformers / xformersがインストールされていないようです")
-
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    # def forward_xformers(self, x, context=None, mask=None):
-    #     h = self.heads
-    #     q_in = self.to_q(x)
-
-    #     context = default(context, x)
-    #     context = context.to(x.dtype)
-
-    #     if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-    #         context_k, context_v = self.hypernetwork.forward(x, context)
-    #         context_k = context_k.to(x.dtype)
-    #         context_v = context_v.to(x.dtype)
-    #     else:
-    #         context_k = context
-    #         context_v = context
-
-    #     k_in = self.to_k(context_k)
-    #     v_in = self.to_v(context_v)
-
-    #     q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
-    #     del q_in, k_in, v_in
-
-    #     q = q.contiguous()
-    #     k = k.contiguous()
-    #     v = v.contiguous()
-    #     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # 最適なのを選んでくれる
-
-    #     out = rearrange(out, "b n h d -> b n (h d)", h=h)
-
-    #     # diffusers 0.7.0~
-    #     out = self.to_out[0](out)
-    #     out = self.to_out[1](out)
-    #     return out
-
-    # diffusers.models.attention.CrossAttention.forward = forward_xformers
 
 """
 def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
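After this hunk, enabling either attention backend goes through the custom U-Net's own switch instead of monkey-patching CrossAttention.forward. A hedged sketch, not part of the commit, of what the rewritten logic amounts to; it assumes `unet` is the repo's own library.original_unet.UNet2DConditionModel (imported in the earlier hunk), and the flag order (xformers, mem_eff) is inferred from the two calls shown in the diff:

from library.original_unet import UNet2DConditionModel

def enable_attention(unet: UNet2DConditionModel, mem_eff_attn: bool, xformers: bool) -> None:
    """Illustrative mirror of the rewritten replace_unet_modules() logic."""
    if mem_eff_attn:
        # Built-in FlashAttention-style attention; works without xformers installed.
        unet.set_use_memory_efficient_attention(False, True)
    elif xformers:
        try:
            import xformers.ops  # noqa: F401  (imported only to verify xformers is installed)
        except ImportError:
            raise ImportError("No xformers / xformers does not appear to be installed")
        unet.set_use_memory_efficient_attention(True, False)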