mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support both 0.10.2 and 0.17.0 for Diffusers
This commit is contained in:
@@ -161,10 +161,46 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
|
|||||||
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
|
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
|
||||||
replace_vae_attn_to_xformers()
|
replace_vae_attn_to_xformers()
|
||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_memory_efficient():
|
def replace_vae_attn_to_memory_efficient():
|
||||||
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
|
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
|
||||||
flash_func = FlashAttentionFunction
|
flash_func = FlashAttentionFunction
|
||||||
|
|
||||||
|
def forward_flash_attn_0_14(self, hidden_states, **kwargs):
    """FlashAttention-based forward for the Diffusers < 0.15 ``AttentionBlock``.

    Installed as ``diffusers.models.attention.AttentionBlock.forward`` by the
    enclosing function when the old Diffusers API is detected.

    ``self`` must expose ``group_norm``, ``query``, ``key``, ``value``,
    ``proj_attn``, ``num_heads`` and ``rescale_output_factor``.

    Args:
        hidden_states: 4-D tensor unpacked as (batch, channel, height, width).
        **kwargs: accepted for call-compatibility and ignored.

    Returns:
        Tensor reshaped back to (batch, channel, height, width): the projected
        attention output plus the residual, divided by
        ``self.rescale_output_factor``.
    """
    q_bucket_size = 512
    k_bucket_size = 1024

    residual = hidden_states
    batch, channel, height, width = hidden_states.shape

    # norm
    hidden_states = self.group_norm(hidden_states)

    # flatten the spatial dims and move channels last: (batch, height*width, channel)
    hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

    # proj to q, k, v
    query_proj = self.query(hidden_states)
    key_proj = self.key(hidden_states)
    value_proj = self.value(hidden_states)

    # split the channel dim into heads: (batch, heads, tokens, head_dim)
    query_proj, key_proj, value_proj = map(
        lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
    )

    # NOTE(review): the None / False arguments are presumably mask and causal flag -- confirm
    # against FlashAttentionFunction.forward.
    out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)

    # merge heads back: (batch, tokens, channel)
    out = rearrange(out, "b h n d -> b n (h d)")

    # compute next hidden_states
    # linear proj -- must consume the attention output; projecting the pre-attention
    # tensor would silently drop the attention result computed above.
    hidden_states = self.proj_attn(out)

    hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

    # res connect and rescale
    hidden_states = (hidden_states + residual) / self.rescale_output_factor
    return hidden_states
|
||||||
|
|
||||||
def forward_flash_attn(self, hidden_states, **kwargs):
|
def forward_flash_attn(self, hidden_states, **kwargs):
|
||||||
q_bucket_size = 512
|
q_bucket_size = 512
|
||||||
k_bucket_size = 1024
|
k_bucket_size = 1024
|
||||||
@@ -202,6 +238,9 @@ def replace_vae_attn_to_memory_efficient():
|
|||||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
if diffusers.__version__ < "0.15.0":
|
||||||
|
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
|
||||||
|
else:
|
||||||
diffusers.models.attention_processor.Attention.forward = forward_flash_attn
|
diffusers.models.attention_processor.Attention.forward = forward_flash_attn
|
||||||
|
|
||||||
|
|
||||||
@@ -209,6 +248,40 @@ def replace_vae_attn_to_xformers():
|
|||||||
print("VAE: Attention.forward has been replaced to xformers")
|
print("VAE: Attention.forward has been replaced to xformers")
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
|
def forward_xformers_0_14(self, hidden_states, **kwargs):
    """xformers-based forward for the Diffusers < 0.15 ``AttentionBlock``.

    Installed as ``diffusers.models.attention.AttentionBlock.forward`` by the
    enclosing function when the old Diffusers API is detected.

    ``self`` must expose ``group_norm``, ``query``, ``key``, ``value``,
    ``proj_attn``, ``num_heads`` and ``rescale_output_factor``.

    Args:
        hidden_states: 4-D tensor unpacked as (batch, channel, height, width).
        **kwargs: accepted for call-compatibility and ignored.

    Returns:
        Tensor reshaped back to (batch, channel, height, width): the projected
        attention output plus the residual, divided by
        ``self.rescale_output_factor``.
    """
    residual = hidden_states
    batch, channel, height, width = hidden_states.shape

    # norm
    hidden_states = self.group_norm(hidden_states)

    # flatten the spatial dims and move channels last: (batch, height*width, channel)
    hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

    # proj to q, k, v
    query_proj = self.query(hidden_states)
    key_proj = self.key(hidden_states)
    value_proj = self.value(hidden_states)

    # xformers' memory_efficient_attention takes (batch, tokens, heads, head_dim);
    # a (batch, heads, tokens, head_dim) layout would make it attend across the
    # head axis instead of across tokens.
    query_proj, key_proj, value_proj = map(
        lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.num_heads), (query_proj, key_proj, value_proj)
    )

    query_proj = query_proj.contiguous()
    key_proj = key_proj.contiguous()
    value_proj = value_proj.contiguous()
    out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)

    # merge heads back: (batch, tokens, channel)
    out = rearrange(out, "b n h d -> b n (h d)")

    # compute next hidden_states
    # linear proj -- must consume the attention output; projecting the pre-attention
    # tensor would silently drop the attention result computed above.
    hidden_states = self.proj_attn(out)

    hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

    # res connect and rescale
    hidden_states = (hidden_states + residual) / self.rescale_output_factor
    return hidden_states
|
||||||
|
|
||||||
def forward_xformers(self, hidden_states, **kwargs):
|
def forward_xformers(self, hidden_states, **kwargs):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
batch, channel, height, width = hidden_states.shape
|
batch, channel, height, width = hidden_states.shape
|
||||||
@@ -246,6 +319,9 @@ def replace_vae_attn_to_xformers():
|
|||||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
if diffusers.__version__ < "0.15.0":
|
||||||
|
diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
|
||||||
|
else:
|
||||||
diffusers.models.attention_processor.Attention.forward = forward_xformers
|
diffusers.models.attention_processor.Attention.forward = forward_xformers
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import diffusers
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
@@ -127,6 +128,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
|||||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||||
|
|
||||||
|
if diffusers.__version__ < "0.15.0":
|
||||||
|
new_item = new_item.replace("q.weight", "query.weight")
|
||||||
|
new_item = new_item.replace("q.bias", "query.bias")
|
||||||
|
|
||||||
|
new_item = new_item.replace("k.weight", "key.weight")
|
||||||
|
new_item = new_item.replace("k.bias", "key.bias")
|
||||||
|
|
||||||
|
new_item = new_item.replace("v.weight", "value.weight")
|
||||||
|
new_item = new_item.replace("v.bias", "value.bias")
|
||||||
|
|
||||||
|
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||||
|
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||||
|
else:
|
||||||
new_item = new_item.replace("q.weight", "to_q.weight")
|
new_item = new_item.replace("q.weight", "to_q.weight")
|
||||||
new_item = new_item.replace("q.bias", "to_q.bias")
|
new_item = new_item.replace("q.bias", "to_q.bias")
|
||||||
|
|
||||||
@@ -192,7 +206,15 @@ def assign_to_checkpoint(
|
|||||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||||
|
|
||||||
# proj_attn.weight has to be converted from conv 1D to linear
|
# proj_attn.weight has to be converted from conv 1D to linear
|
||||||
|
reshaping = False
|
||||||
|
if diffusers.__version__ < "0.15.0":
|
||||||
|
if "proj_attn.weight" in new_path:
|
||||||
|
reshaping = True
|
||||||
|
else:
|
||||||
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
||||||
|
reshaping = True
|
||||||
|
|
||||||
|
if reshaping:
|
||||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
||||||
else:
|
else:
|
||||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||||
@@ -780,6 +802,16 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||||
|
|
||||||
|
if diffusers.__version__ < "0.15.0":
|
||||||
|
vae_conversion_map_attn = [
|
||||||
|
# (stable-diffusion, HF Diffusers)
|
||||||
|
("norm.", "group_norm."),
|
||||||
|
("q.", "query."),
|
||||||
|
("k.", "key."),
|
||||||
|
("v.", "value."),
|
||||||
|
("proj_out.", "proj_attn."),
|
||||||
|
]
|
||||||
|
else:
|
||||||
vae_conversion_map_attn = [
|
vae_conversion_map_attn = [
|
||||||
# (stable-diffusion, HF Diffusers)
|
# (stable-diffusion, HF Diffusers)
|
||||||
("norm.", "group_norm."),
|
("norm.", "group_norm."),
|
||||||
|
|||||||
Reference in New Issue
Block a user