diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 34857af3..71daa9a1 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -161,9 +161,45 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
         # Use Diffusers' xformers for now; only the MidBlock has Attention
         replace_vae_attn_to_xformers()
 
+
 def replace_vae_attn_to_memory_efficient():
     print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
-    flash_func =FlashAttentionFunction
+    flash_func = FlashAttentionFunction
+
+    def forward_flash_attn_0_14(self, hidden_states, **kwargs):
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_proj, key_proj, value_proj = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
+        )
+
+        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        # compute next hidden_states
+        # linear proj
+        hidden_states = self.proj_attn(out)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
 
     def forward_flash_attn(self, hidden_states, **kwargs):
         q_bucket_size = 512
@@ -202,13 +238,50 @@ def replace_vae_attn_to_memory_efficient():
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states
 
-    diffusers.models.attention_processor.Attention.forward = forward_flash_attn
+    if diffusers.__version__ < "0.15.0":
+        diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
+    else:
+        diffusers.models.attention_processor.Attention.forward = forward_flash_attn
 
 
 def replace_vae_attn_to_xformers():
     print("VAE: Attention.forward has been replaced to xformers")
     import xformers.ops
-
+
+    def forward_xformers_0_14(self, hidden_states, **kwargs):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_proj, key_proj, value_proj = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
+        )
+
+        query_proj = query_proj.contiguous()
+        key_proj = key_proj.contiguous()
+        value_proj = value_proj.contiguous()
+        out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(out)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
     def forward_xformers(self, hidden_states, **kwargs):
         residual = hidden_states
         batch, channel, height, width = hidden_states.shape
@@ -246,7 +319,10 @@ def replace_vae_attn_to_xformers():
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states
 
-    diffusers.models.attention_processor.Attention.forward = forward_xformers
+    if diffusers.__version__ < "0.15.0":
+        diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
+    else:
+        diffusers.models.attention_processor.Attention.forward = forward_xformers
 
 
 # endregion
diff --git a/library/model_util.py b/library/model_util.py
index 0773188c..7ca8f194 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -4,6 +4,7 @@ import math
 import os
 import torch
+import diffusers
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
 from safetensors.torch import load_file, save_file
 
@@ -127,17 +128,30 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("norm.weight", "group_norm.weight")
         new_item = new_item.replace("norm.bias", "group_norm.bias")
 
-        new_item = new_item.replace("q.weight", "to_q.weight")
-        new_item = new_item.replace("q.bias", "to_q.bias")
+        if diffusers.__version__ < "0.15.0":
+            new_item = new_item.replace("q.weight", "query.weight")
+            new_item = new_item.replace("q.bias", "query.bias")
 
-        new_item = new_item.replace("k.weight", "to_k.weight")
-        new_item = new_item.replace("k.bias", "to_k.bias")
+            new_item = new_item.replace("k.weight", "key.weight")
+            new_item = new_item.replace("k.bias", "key.bias")
 
-        new_item = new_item.replace("v.weight", "to_v.weight")
-        new_item = new_item.replace("v.bias", "to_v.bias")
+            new_item = new_item.replace("v.weight", "value.weight")
+            new_item = new_item.replace("v.bias", "value.bias")
 
-        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
-        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+            new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+            new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        else:
+            new_item = new_item.replace("q.weight", "to_q.weight")
+            new_item = new_item.replace("q.bias", "to_q.bias")
+
+            new_item = new_item.replace("k.weight", "to_k.weight")
+            new_item = new_item.replace("k.bias", "to_k.bias")
+
+            new_item = new_item.replace("v.weight", "to_v.weight")
+            new_item = new_item.replace("v.bias", "to_v.bias")
+
+            new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+            new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
@@ -192,7 +206,15 @@ def assign_to_checkpoint(
                 new_path = new_path.replace(replacement["old"], replacement["new"])
 
         # proj_attn.weight has to be converted from conv 1D to linear
-        if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
+        reshaping = False
+        if diffusers.__version__ < "0.15.0":
+            if "proj_attn.weight" in new_path:
+                reshaping = True
+        else:
+            if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
+                reshaping = True
+
+        if reshaping:
             checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
         else:
             checkpoint[new_path] = old_checkpoint[path["old"]]
@@ -780,14 +802,24 @@ def convert_vae_state_dict(vae_state_dict):
         sd_mid_res_prefix = f"mid.block_{i+1}."
         vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
 
-    vae_conversion_map_attn = [
-        # (stable-diffusion, HF Diffusers)
-        ("norm.", "group_norm."),
-        ("q.", "to_q."),
-        ("k.", "to_k."),
-        ("v.", "to_v."),
-        ("proj_out.", "to_out.0."),
-    ]
+    if diffusers.__version__ < "0.15.0":
+        vae_conversion_map_attn = [
+            # (stable-diffusion, HF Diffusers)
+            ("norm.", "group_norm."),
+            ("q.", "query."),
+            ("k.", "key."),
+            ("v.", "value."),
+            ("proj_out.", "proj_attn."),
+        ]
+    else:
+        vae_conversion_map_attn = [
+            # (stable-diffusion, HF Diffusers)
+            ("norm.", "group_norm."),
+            ("q.", "to_q."),
+            ("k.", "to_k."),
+            ("v.", "to_v."),
+            ("proj_out.", "to_out.0."),
+        ]
 
     mapping = {k: k for k in vae_state_dict.keys()}
     for k, v in mapping.items():