mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix eps value, enable xformers, etc.
This commit is contained in:
@@ -1792,7 +1792,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
||||
if mem_eff_attn:
|
||||
replace_unet_cross_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
replace_unet_cross_attn_to_xformers()
|
||||
replace_unet_cross_attn_to_xformers(unet)
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
|
||||
@@ -1827,55 +1827,59 @@ def replace_unet_cross_attn_to_memory_efficient():
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
|
||||
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
|
||||
out = self.to_out[0](out)
|
||||
out = self.to_out[1](out)
|
||||
# out = self.to_out[1](out)
|
||||
return out
|
||||
|
||||
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
||||
# diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
||||
from library.original_unet import CrossAttention
|
||||
|
||||
CrossAttention.forward = forward_flash_attn
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
def replace_unet_cross_attn_to_xformers(unet):
|
||||
print("CrossAttention.forward has been replaced to enable xformers.")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
def forward_xformers(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
context = default(context, x)
|
||||
context = context.to(x.dtype)
|
||||
# def forward_xformers(self, x, context=None, mask=None):
|
||||
# h = self.heads
|
||||
# q_in = self.to_q(x)
|
||||
|
||||
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
||||
context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
context_k = context_k.to(x.dtype)
|
||||
context_v = context_v.to(x.dtype)
|
||||
else:
|
||||
context_k = context
|
||||
context_v = context
|
||||
# context = default(context, x)
|
||||
# context = context.to(x.dtype)
|
||||
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
# if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
||||
# context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
# context_k = context_k.to(x.dtype)
|
||||
# context_v = context_v.to(x.dtype)
|
||||
# else:
|
||||
# context_k = context
|
||||
# context_v = context
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
# k_in = self.to_k(context_k)
|
||||
# v_in = self.to_v(context_v)
|
||||
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
||||
# del q_in, k_in, v_in
|
||||
|
||||
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
||||
# q = q.contiguous()
|
||||
# k = k.contiguous()
|
||||
# v = v.contiguous()
|
||||
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
|
||||
# diffusers 0.7.0~
|
||||
out = self.to_out[0](out)
|
||||
out = self.to_out[1](out)
|
||||
return out
|
||||
# out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
||||
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
# # diffusers 0.7.0~
|
||||
# out = self.to_out[0](out)
|
||||
# out = self.to_out[1](out)
|
||||
# return out
|
||||
|
||||
# diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
|
||||
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user