support both 0.10.2 and 0.17.0 for Diffusers

This commit is contained in:
ykume
2023-06-11 18:54:50 +09:00
parent 0315611b11
commit 4d0c06e397
2 changed files with 129 additions and 21 deletions

View File

@@ -161,9 +161,45 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
replace_vae_attn_to_xformers()
def replace_vae_attn_to_memory_efficient():
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func =FlashAttentionFunction
flash_func = FlashAttentionFunction
def forward_flash_attn_0_14(self, hidden_states, **kwargs):
q_bucket_size = 512
k_bucket_size = 1024
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
out = rearrange(out, "b h n d -> b n (h d)")
# compute next hidden_states
# linear proj
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_flash_attn(self, hidden_states, **kwargs):
q_bucket_size = 512
@@ -202,13 +238,50 @@ def replace_vae_attn_to_memory_efficient():
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
diffusers.models.attention_processor.Attention.forward = forward_flash_attn
if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
else:
diffusers.models.attention_processor.Attention.forward = forward_flash_attn
def replace_vae_attn_to_xformers():
print("VAE: Attention.forward has been replaced to xformers")
import xformers.ops
def forward_xformers_0_14(self, hidden_states, **kwargs):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_proj, key_proj, value_proj = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
)
query_proj = query_proj.contiguous()
key_proj = key_proj.contiguous()
value_proj = value_proj.contiguous()
out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
out = rearrange(out, "b h n d -> b n (h d)")
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward_xformers(self, hidden_states, **kwargs):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
@@ -246,7 +319,10 @@ def replace_vae_attn_to_xformers():
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
diffusers.models.attention_processor.Attention.forward = forward_xformers
if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
else:
diffusers.models.attention_processor.Attention.forward = forward_xformers
# endregion