fix eps value, enable xformers, etc.

Author: Kohya S
Date:   2023-06-03 21:29:27 +09:00
commit c0a7df9ee1 (parent 5db792b10b)

3 changed files with 171 additions and 78 deletions


@@ -1792,7 +1792,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
     if mem_eff_attn:
         replace_unet_cross_attn_to_memory_efficient()
     elif xformers:
-        replace_unet_cross_attn_to_xformers()
+        replace_unet_cross_attn_to_xformers(unet)


 def replace_unet_cross_attn_to_memory_efficient():
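
The dispatcher now passes the U-Net instance through, because the xformers path configures the model object itself rather than patching a class attribute (see the second hunk below). A minimal call-site sketch; the keyword form and the unet variable are illustrative, not taken from this commit:

    # Hypothetical call site: after loading the Stable Diffusion U-Net,
    # route it through the dispatcher changed above.
    replace_unet_modules(unet, mem_eff_attn=False, xformers=True)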
@@ -1827,55 +1827,59 @@ def replace_unet_cross_attn_to_memory_efficient():
         out = rearrange(out, "b h n d -> b n (h d)")

         # diffusers 0.7.0~: did they really have to change this? (;´Д`)
         out = self.to_out[0](out)
-        out = self.to_out[1](out)
+        # out = self.to_out[1](out)
         return out

-    diffusers.models.attention.CrossAttention.forward = forward_flash_attn
+    # diffusers.models.attention.CrossAttention.forward = forward_flash_attn
+    from library.original_unet import CrossAttention
+
+    CrossAttention.forward = forward_flash_attn


-def replace_unet_cross_attn_to_xformers():
+def replace_unet_cross_attn_to_xformers(unet):
     print("CrossAttention.forward has been replaced to enable xformers.")
     try:
         import xformers.ops
     except ImportError:
         raise ImportError("No xformers / xformersがインストールされていないようです")

-    def forward_xformers(self, x, context=None, mask=None):
-        h = self.heads
-        q_in = self.to_q(x)
+    unet.set_use_memory_efficient_attention_xformers(True)

-        context = default(context, x)
-        context = context.to(x.dtype)
+    # def forward_xformers(self, x, context=None, mask=None):
+    #     h = self.heads
+    #     q_in = self.to_q(x)

-        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-            context_k, context_v = self.hypernetwork.forward(x, context)
-            context_k = context_k.to(x.dtype)
-            context_v = context_v.to(x.dtype)
-        else:
-            context_k = context
-            context_v = context
+    #     context = default(context, x)
+    #     context = context.to(x.dtype)

-        k_in = self.to_k(context_k)
-        v_in = self.to_v(context_v)
+    #     if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
+    #         context_k, context_v = self.hypernetwork.forward(x, context)
+    #         context_k = context_k.to(x.dtype)
+    #         context_v = context_v.to(x.dtype)
+    #     else:
+    #         context_k = context
+    #         context_v = context

-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
+    #     k_in = self.to_k(context_k)
+    #     v_in = self.to_v(context_v)

-        q = q.contiguous()
-        k = k.contiguous()
-        v = v.contiguous()
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the optimal kernel automatically
+    #     q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
+    #     del q_in, k_in, v_in

-        out = rearrange(out, "b n h d -> b n (h d)", h=h)
+    #     q = q.contiguous()
+    #     k = k.contiguous()
+    #     v = v.contiguous()
+    #     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the optimal kernel automatically

-        # diffusers 0.7.0~
-        out = self.to_out[0](out)
-        out = self.to_out[1](out)
-        return out
+    #     out = rearrange(out, "b n h d -> b n (h d)", h=h)

-    diffusers.models.attention.CrossAttention.forward = forward_xformers
+    #     # diffusers 0.7.0~
+    #     out = self.to_out[0](out)
+    #     out = self.to_out[1](out)
+    #     return out
+
+    # diffusers.models.attention.CrossAttention.forward = forward_xformers
     """