support both 0.10.2 and 0.17.0 for Diffusers

This commit is contained in:
ykume
2023-06-11 18:54:50 +09:00
parent 0315611b11
commit 4d0c06e397
2 changed files with 129 additions and 21 deletions

View File

@@ -161,9 +161,45 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
replace_vae_attn_to_xformers() replace_vae_attn_to_xformers()
def replace_vae_attn_to_memory_efficient(): def replace_vae_attn_to_memory_efficient():
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func =FlashAttentionFunction flash_func = FlashAttentionFunction
def forward_flash_attn_0_14(self, hidden_states, **kwargs):
    """AttentionBlock.forward replacement for Diffusers < 0.15, using FlashAttention.

    Mirrors the pre-0.15 ``diffusers.models.attention.AttentionBlock.forward``
    (which exposes ``query``/``key``/``value``/``proj_attn`` projections and
    ``num_heads``), but routes the attention computation through ``flash_func``
    (the FlashAttentionFunction captured by the enclosing scope).

    hidden_states: (batch, channel, height, width) feature map; returns a
    tensor of the same shape.
    """
    q_bucket_size = 512
    k_bucket_size = 1024

    residual = hidden_states
    batch, channel, height, width = hidden_states.shape

    # norm
    hidden_states = self.group_norm(hidden_states)

    # flatten spatial dims: (b, c, h*w) -> (b, h*w, c) for linear projections
    hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

    # proj to q, k, v (pre-0.15 AttentionBlock attribute names)
    query_proj = self.query(hidden_states)
    key_proj = self.key(hidden_states)
    value_proj = self.value(hidden_states)

    query_proj, key_proj, value_proj = map(
        lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
    )

    out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)

    out = rearrange(out, "b h n d -> b n (h d)")

    # compute next hidden_states
    # linear proj of the attention output.
    # BUGFIX: was self.proj_attn(hidden_states), which silently discarded the
    # attention result `out` and projected the pre-attention activations instead.
    hidden_states = self.proj_attn(out)

    # back to (b, c, h, w)
    hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

    # res connect and rescale
    hidden_states = (hidden_states + residual) / self.rescale_output_factor
    return hidden_states
def forward_flash_attn(self, hidden_states, **kwargs): def forward_flash_attn(self, hidden_states, **kwargs):
q_bucket_size = 512 q_bucket_size = 512
@@ -202,13 +238,50 @@ def replace_vae_attn_to_memory_efficient():
hidden_states = (hidden_states + residual) / self.rescale_output_factor hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states return hidden_states
diffusers.models.attention_processor.Attention.forward = forward_flash_attn if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
else:
diffusers.models.attention_processor.Attention.forward = forward_flash_attn
def replace_vae_attn_to_xformers(): def replace_vae_attn_to_xformers():
print("VAE: Attention.forward has been replaced to xformers") print("VAE: Attention.forward has been replaced to xformers")
import xformers.ops import xformers.ops
def forward_xformers_0_14(self, hidden_states, **kwargs):
    """AttentionBlock.forward replacement for Diffusers < 0.15, using xformers.

    Mirrors the pre-0.15 ``diffusers.models.attention.AttentionBlock.forward``
    (which exposes ``query``/``key``/``value``/``proj_attn`` projections and
    ``num_heads``), but computes attention with
    ``xformers.ops.memory_efficient_attention``.

    hidden_states: (batch, channel, height, width) feature map; returns a
    tensor of the same shape.
    """
    residual = hidden_states
    batch, channel, height, width = hidden_states.shape

    # norm
    hidden_states = self.group_norm(hidden_states)

    # flatten spatial dims: (b, c, h*w) -> (b, h*w, c) for linear projections
    hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

    # proj to q, k, v (pre-0.15 AttentionBlock attribute names)
    query_proj = self.query(hidden_states)
    key_proj = self.key(hidden_states)
    value_proj = self.value(hidden_states)

    query_proj, key_proj, value_proj = map(
        lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
    )

    # xformers requires contiguous inputs
    query_proj = query_proj.contiguous()
    key_proj = key_proj.contiguous()
    value_proj = value_proj.contiguous()
    out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)

    out = rearrange(out, "b h n d -> b n (h d)")

    # compute next hidden_states
    # linear proj of the attention output.
    # BUGFIX: was self.proj_attn(hidden_states), which silently discarded the
    # attention result `out` and projected the pre-attention activations instead.
    hidden_states = self.proj_attn(out)

    # back to (b, c, h, w)
    hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

    # res connect and rescale
    hidden_states = (hidden_states + residual) / self.rescale_output_factor
    return hidden_states
def forward_xformers(self, hidden_states, **kwargs): def forward_xformers(self, hidden_states, **kwargs):
residual = hidden_states residual = hidden_states
batch, channel, height, width = hidden_states.shape batch, channel, height, width = hidden_states.shape
@@ -246,7 +319,10 @@ def replace_vae_attn_to_xformers():
hidden_states = (hidden_states + residual) / self.rescale_output_factor hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states return hidden_states
diffusers.models.attention_processor.Attention.forward = forward_xformers if diffusers.__version__ < "0.15.0":
diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
else:
diffusers.models.attention_processor.Attention.forward = forward_xformers
# endregion # endregion

View File

@@ -4,6 +4,7 @@
import math import math
import os import os
import torch import torch
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
@@ -127,17 +128,30 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias") new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "to_q.weight") if diffusers.__version__ < "0.15.0":
new_item = new_item.replace("q.bias", "to_q.bias") new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "to_k.weight") new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "to_k.bias") new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "to_v.weight") new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "to_v.bias") new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight") new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
else:
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
@@ -192,7 +206,15 @@ def assign_to_checkpoint(
new_path = new_path.replace(replacement["old"], replacement["new"]) new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear # proj_attn.weight has to be converted from conv 1D to linear
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: reshaping = False
if diffusers.__version__ < "0.15.0":
if "proj_attn.weight" in new_path:
reshaping = True
else:
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
reshaping = True
if reshaping:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else: else:
checkpoint[new_path] = old_checkpoint[path["old"]] checkpoint[new_path] = old_checkpoint[path["old"]]
@@ -780,14 +802,24 @@ def convert_vae_state_dict(vae_state_dict):
sd_mid_res_prefix = f"mid.block_{i+1}." sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [ if diffusers.__version__ < "0.15.0":
# (stable-diffusion, HF Diffusers) vae_conversion_map_attn = [
("norm.", "group_norm."), # (stable-diffusion, HF Diffusers)
("q.", "to_q."), ("norm.", "group_norm."),
("k.", "to_k."), ("q.", "query."),
("v.", "to_v."), ("k.", "key."),
("proj_out.", "to_out.0."), ("v.", "value."),
] ("proj_out.", "proj_attn."),
]
else:
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "to_q."),
("k.", "to_k."),
("v.", "to_v."),
("proj_out.", "to_out.0."),
]
mapping = {k: k for k in vae_state_dict.keys()} mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items(): for k, v in mapping.items():