fix eps value, enable xformers, etc.

This commit is contained in:
Kohya S
2023-06-03 21:29:27 +09:00
parent 5db792b10b
commit c0a7df9ee1
3 changed files with 171 additions and 78 deletions

View File

@@ -317,7 +317,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
if mem_eff_attn:
replace_unet_cross_attn_to_memory_efficient()
elif xformers:
replace_unet_cross_attn_to_xformers()
replace_unet_cross_attn_to_xformers(unet)
def replace_unet_cross_attn_to_memory_efficient():
@@ -357,50 +357,55 @@ def replace_unet_cross_attn_to_memory_efficient():
out = self.to_out[1](out)
return out
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
# diffusers.models.attention.CrossAttention.forward = forward_flash_attn
# TODO U-Net側に移す
from library.original_unet import CrossAttention
CrossAttention.forward = forward_flash_attn
def replace_unet_cross_attn_to_xformers():
def replace_unet_cross_attn_to_xformers(unet:UNet2DConditionModel):
print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
try:
import xformers.ops
except ImportError:
raise ImportError("No xformers / xformersがインストールされていないようです")
unet.set_use_memory_efficient_attention_xformers(True)
def forward_xformers(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
# def forward_xformers(self, x, context=None, mask=None):
# h = self.heads
# q_in = self.to_q(x)
context = default(context, x)
context = context.to(x.dtype)
# context = default(context, x)
# context = context.to(x.dtype)
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
context_k, context_v = self.hypernetwork.forward(x, context)
context_k = context_k.to(x.dtype)
context_v = context_v.to(x.dtype)
else:
context_k = context
context_v = context
# if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
# context_k, context_v = self.hypernetwork.forward(x, context)
# context_k = context_k.to(x.dtype)
# context_v = context_v.to(x.dtype)
# else:
# context_k = context
# context_v = context
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
# k_in = self.to_k(context_k)
# v_in = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
# del q_in, k_in, v_in
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
# q = q.contiguous()
# k = k.contiguous()
# v = v.contiguous()
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
out = rearrange(out, "b n h d -> b n (h d)", h=h)
# out = rearrange(out, "b n h d -> b n (h d)", h=h)
# diffusers 0.7.0~
out = self.to_out[0](out)
out = self.to_out[1](out)
return out
# # diffusers 0.7.0~
# out = self.to_out[0](out)
# out = self.to_out[1](out)
# return out
diffusers.models.attention.CrossAttention.forward = forward_xformers
# diffusers.models.attention.CrossAttention.forward = forward_xformers
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):