Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
support sdpa
@@ -137,7 +137,7 @@ USE_CUTOUTS = False
"""


def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
    if mem_eff_attn:
        print("Enable memory efficient attention for U-Net")
@@ -151,56 +151,26 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
            raise ImportError("No xformers / xformersがインストールされていないようです")

        unet.set_use_memory_efficient_attention(True, False)
    elif sdpa:
        print("Enable SDPA for U-Net")
        unet.set_use_memory_efficient_attention(False, False)
        unet.set_use_sdpa(True)


# TODO common train_util.py
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa):
    if mem_eff_attn:
        replace_vae_attn_to_memory_efficient()
    elif xformers:
        # For now, use Diffusers' xformers; only the MidBlock has Attention
        replace_vae_attn_to_xformers()
    elif sdpa:
        replace_vae_attn_to_sdpa()


def replace_vae_attn_to_memory_efficient():
    print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
    flash_func = FlashAttentionFunction

    def forward_flash_attn_0_14(self, hidden_states, **kwargs):
        q_bucket_size = 512
        k_bucket_size = 1024

        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        query_proj, key_proj, value_proj = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
        )

        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, "b h n d -> b n (h d)")

        # compute next hidden_states
        # linear proj
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward_flash_attn(self, hidden_states, **kwargs):
        q_bucket_size = 512
        k_bucket_size = 1024
@@ -238,6 +208,15 @@ def replace_vae_attn_to_memory_efficient():
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward_flash_attn_0_14(self, hidden_states, **kwargs):
        if not hasattr(self, "to_q"):
            self.to_q = self.query
            self.to_k = self.key
            self.to_v = self.value
            self.to_out = [self.proj_attn, torch.nn.Identity()]
            self.heads = self.num_heads
        return forward_flash_attn(self, hidden_states, **kwargs)

    if diffusers.__version__ < "0.15.0":
        diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
    else:
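The compatibility shim above can be read in isolation: alias the pre-0.15 attribute names onto the new ones once, then delegate to the new forward. A minimal sketch of that pattern follows (the class and function names below are hypothetical stand-ins, not diffusers APIs):

import torch

class OldStyleBlock(torch.nn.Module):
    # stand-in for the pre-0.15 AttentionBlock attribute layout (query/key/value/proj_attn/num_heads)
    def __init__(self, dim=8, num_heads=2):
        super().__init__()
        self.query = torch.nn.Linear(dim, dim)
        self.key = torch.nn.Linear(dim, dim)
        self.value = torch.nn.Linear(dim, dim)
        self.proj_attn = torch.nn.Linear(dim, dim)
        self.num_heads = num_heads

def new_forward(self, x):
    # code written against the post-0.15 names (to_q / to_k / to_v / to_out / heads)
    return self.to_out[0](self.to_q(x) + self.to_k(x) + self.to_v(x))

def compat_forward(self, x):
    if not hasattr(self, "to_q"):  # alias once, as forward_flash_attn_0_14 does above
        self.to_q, self.to_k, self.to_v = self.query, self.key, self.value
        self.to_out = [self.proj_attn, torch.nn.Identity()]
        self.heads = self.num_heads
    return new_forward(self, x)

print(compat_forward(OldStyleBlock(), torch.randn(1, 8)).shape)  # torch.Size([1, 8])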
@@ -248,40 +227,6 @@ def replace_vae_attn_to_xformers():
    print("VAE: Attention.forward has been replaced to xformers")
    import xformers.ops

    def forward_xformers_0_14(self, hidden_states, **kwargs):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        query_proj, key_proj, value_proj = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
        )

        query_proj = query_proj.contiguous()
        key_proj = key_proj.contiguous()
        value_proj = value_proj.contiguous()
        out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)

        out = rearrange(out, "b h n d -> b n (h d)")

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward_xformers(self, hidden_states, **kwargs):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape
@@ -319,12 +264,75 @@ def replace_vae_attn_to_xformers():
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward_xformers_0_14(self, hidden_states, **kwargs):
        if not hasattr(self, "to_q"):
            self.to_q = self.query
            self.to_k = self.key
            self.to_v = self.value
            self.to_out = [self.proj_attn, torch.nn.Identity()]
            self.heads = self.num_heads
        return forward_xformers(self, hidden_states, **kwargs)

    if diffusers.__version__ < "0.15.0":
        diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
    else:
        diffusers.models.attention_processor.Attention.forward = forward_xformers


def replace_vae_attn_to_sdpa():
    print("VAE: Attention.forward has been replaced to sdpa")

    def forward_sdpa(self, hidden_states, **kwargs):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.to_q(hidden_states)
        key_proj = self.to_k(hidden_states)
        value_proj = self.to_v(hidden_states)

        query_proj, key_proj, value_proj = map(
            lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj)
        )

        out = torch.nn.functional.scaled_dot_product_attention(
            query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        out = rearrange(out, "b n h d -> b n (h d)")

        # compute next hidden_states
        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward_sdpa_0_14(self, hidden_states, **kwargs):
        if not hasattr(self, "to_q"):
            self.to_q = self.query
            self.to_k = self.key
            self.to_v = self.value
            self.to_out = [self.proj_attn, torch.nn.Identity()]
            self.heads = self.num_heads
        return forward_sdpa(self, hidden_states, **kwargs)

    if diffusers.__version__ < "0.15.0":
        diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14
    else:
        diffusers.models.attention_processor.Attention.forward = forward_sdpa


# endregion

# region main image-generation logic: copied from lpw_stable_diffusion.py (ASL) and modified
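For reference, a minimal, self-contained sketch (independent of this commit; tensor shapes are purely illustrative) of what torch.nn.functional.scaled_dot_product_attention, as used by forward_sdpa above, computes: softmax(q @ k^T / sqrt(d)) @ v over the last two dimensions of its inputs.

import math
import torch
import torch.nn.functional as F

# (batch, heads, tokens, head_dim); SDPA attends over the last two dimensions
q = torch.randn(2, 4, 64, 32)
k = torch.randn(2, 4, 64, 32)
v = torch.randn(2, 4, 64, 32)

sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# the same attention written out by hand
attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
manual_out = attn @ v

print(torch.allclose(sdpa_out, manual_out, atol=1e-5))  # expected: True, up to numerical tolerance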
@@ -2082,8 +2090,9 @@ def main(args):

    # xformers and Hypernetwork support
    if not args.diffusers_xformers:
        replace_unet_modules(unet, not args.xformers, args.xformers)
        replace_vae_modules(vae, not args.xformers, args.xformers)
        mem_eff = not (args.xformers or args.sdpa)
        replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa)
        replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)

    # load the tokenizer
    print("loading tokenizer")
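The new mem_eff expression above means the memory-efficient (FlashAttention) path stays the default when neither --xformers nor --sdpa is given. A small sketch of how the flags resolve (argparse.Namespace stands in for the real parsed arguments):

from argparse import Namespace

for flags in (
    Namespace(xformers=False, sdpa=False),
    Namespace(xformers=True, sdpa=False),
    Namespace(xformers=False, sdpa=True),
):
    mem_eff = not (flags.xformers or flags.sdpa)
    backend = "mem_eff_attn" if mem_eff else ("xformers" if flags.xformers else "sdpa")
    print(vars(flags), "->", backend)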
@@ -3176,6 +3185,7 @@ def setup_parser() -> argparse.ArgumentParser:
    parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する")
    parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する")
    parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")
    parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa")
    parser.add_argument(
        "--diffusers_xformers",
        action="store_true",
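A tiny standalone reproduction of the new option (only --xformers and --sdpa are shown; the real setup_parser() defines many more arguments), just to illustrate that --sdpa is parsed as a boolean switch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")
parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa")

args = parser.parse_args(["--sdpa"])
print(args.xformers, args.sdpa)  # False True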