mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
use xformers in VAE in gen script
This commit is contained in:
@@ -1765,6 +1765,7 @@ class FlashAttentionFunction(torch.autograd.function.Function):
|
||||
|
||||
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
# unet is not used currently, but it is here for future use
|
||||
if mem_eff_attn:
|
||||
replace_unet_cross_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
@@ -1772,7 +1773,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
|
||||
print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
|
||||
print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, x, context=None, mask=None):
|
||||
@@ -1812,7 +1813,7 @@ def replace_unet_cross_attn_to_memory_efficient():
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
print("Replace CrossAttention.forward to use xformers")
|
||||
print("CrossAttention.forward has been replaced to enable xformers.")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
@@ -1854,6 +1855,60 @@ def replace_unet_cross_attn_to_xformers():
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
|
||||
|
||||
"""
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
# vae is not used currently, but it is here for future use
|
||||
if mem_eff_attn:
|
||||
replace_vae_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
|
||||
print("Use Diffusers xformers for VAE")
|
||||
vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
|
||||
vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
|
||||
def replace_vae_attn_to_memory_efficient():
|
||||
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, hidden_states):
|
||||
print("forward_flash_attn")
|
||||
q_bucket_size = 512
|
||||
k_bucket_size = 1024
|
||||
|
||||
residual = hidden_states
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
|
||||
# norm
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
||||
|
||||
# proj to q, k, v
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_proj, key_proj, value_proj = map(
|
||||
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
|
||||
)
|
||||
|
||||
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
||||
|
||||
# res connect and rescale
|
||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
||||
return hidden_states
|
||||
|
||||
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
|
||||
"""
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user