mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix eps value, enable xformers, etc.
This commit is contained in:
@@ -317,7 +317,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
||||
if mem_eff_attn:
|
||||
replace_unet_cross_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
replace_unet_cross_attn_to_xformers()
|
||||
replace_unet_cross_attn_to_xformers(unet)
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
|
||||
@@ -357,50 +357,55 @@ def replace_unet_cross_attn_to_memory_efficient():
|
||||
out = self.to_out[1](out)
|
||||
return out
|
||||
|
||||
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
||||
# diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
||||
# TODO U-Net側に移す
|
||||
from library.original_unet import CrossAttention
|
||||
CrossAttention.forward = forward_flash_attn
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
def replace_unet_cross_attn_to_xformers(unet:UNet2DConditionModel):
|
||||
print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def forward_xformers(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
# def forward_xformers(self, x, context=None, mask=None):
|
||||
# h = self.heads
|
||||
# q_in = self.to_q(x)
|
||||
|
||||
context = default(context, x)
|
||||
context = context.to(x.dtype)
|
||||
# context = default(context, x)
|
||||
# context = context.to(x.dtype)
|
||||
|
||||
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
||||
context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
context_k = context_k.to(x.dtype)
|
||||
context_v = context_v.to(x.dtype)
|
||||
else:
|
||||
context_k = context
|
||||
context_v = context
|
||||
# if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
||||
# context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
# context_k = context_k.to(x.dtype)
|
||||
# context_v = context_v.to(x.dtype)
|
||||
# else:
|
||||
# context_k = context
|
||||
# context_v = context
|
||||
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
# k_in = self.to_k(context_k)
|
||||
# v_in = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
||||
# del q_in, k_in, v_in
|
||||
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
# q = q.contiguous()
|
||||
# k = k.contiguous()
|
||||
# v = v.contiguous()
|
||||
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
|
||||
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
||||
# out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
||||
|
||||
# diffusers 0.7.0~
|
||||
out = self.to_out[0](out)
|
||||
out = self.to_out[1](out)
|
||||
return out
|
||||
# # diffusers 0.7.0~
|
||||
# out = self.to_out[0](out)
|
||||
# out = self.to_out[1](out)
|
||||
# return out
|
||||
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
# diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
|
||||
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
|
||||
Reference in New Issue
Block a user