mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
support sdpa
@@ -494,6 +494,9 @@ class DownBlock2D(nn.Module):
     def set_use_memory_efficient_attention(self, xformers, mem_eff):
         pass
 
+    def set_use_sdpa(self, sdpa):
+        pass
+
     def forward(self, hidden_states, temb=None):
         output_states = ()
 
@@ -564,11 +567,15 @@ class CrossAttention(nn.Module):
 
         self.use_memory_efficient_attention_xformers = False
         self.use_memory_efficient_attention_mem_eff = False
+        self.use_sdpa = False
 
     def set_use_memory_efficient_attention(self, xformers, mem_eff):
         self.use_memory_efficient_attention_xformers = xformers
         self.use_memory_efficient_attention_mem_eff = mem_eff
 
+    def set_use_sdpa(self, sdpa):
+        self.use_sdpa = sdpa
+
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
@@ -588,6 +595,8 @@ class CrossAttention(nn.Module):
             return self.forward_memory_efficient_xformers(hidden_states, context, mask)
         if self.use_memory_efficient_attention_mem_eff:
             return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
+        if self.use_sdpa:
+            return self.forward_sdpa(hidden_states, context, mask)
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
@@ -676,6 +685,26 @@ class CrossAttention(nn.Module):
         out = self.to_out[0](out)
         return out
 
+    def forward_sdpa(self, x, context=None, mask=None):
+        import xformers.ops
+
+        h = self.heads
+        q_in = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+
+        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+        out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+        out = self.to_out[0](out)
+        return out
+
 
 # feedforward
 class GEGLU(nn.Module):
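For reference, a minimal standalone sketch (not part of the commit) of what the new forward_sdpa path computes: torch.nn.functional.scaled_dot_product_attention evaluates softmax(QK^T / sqrt(d)) V, so in the non-causal, unmasked case it should agree with manual attention up to numerical tolerance; the SDPA path itself does not need the xformers import shown in the hunk above. Shapes below are arbitrary.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, h, n, d = 2, 8, 64, 40
    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

    # fused scaled dot-product attention (PyTorch >= 2.0)
    out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

    # manual reference: softmax(QK^T / sqrt(d)) V
    attn = torch.softmax(q @ k.transpose(-2, -1) / (d**0.5), dim=-1)
    out_manual = attn @ v

    print(torch.allclose(out_sdpa, out_manual, atol=1e-5))  # True, up to numerical tolerance

    # back to the "b n (h d)" layout used by forward_sdpa
    print(rearrange(out_sdpa, "b h n d -> b n (h d)", h=h).shape)  # torch.Size([2, 64, 320])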
@@ -759,6 +788,10 @@ class BasicTransformerBlock(nn.Module):
         self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
         self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
 
+    def set_use_sdpa(self, sdpa: bool):
+        self.attn1.set_use_sdpa(sdpa)
+        self.attn2.set_use_sdpa(sdpa)
+
     def forward(self, hidden_states, context=None, timestep=None):
         # 1. Self-Attention
         norm_hidden_states = self.norm1(hidden_states)
@@ -820,6 +853,10 @@ class Transformer2DModel(nn.Module):
         for transformer in self.transformer_blocks:
             transformer.set_use_memory_efficient_attention(xformers, mem_eff)
 
+    def set_use_sdpa(self, sdpa):
+        for transformer in self.transformer_blocks:
+            transformer.set_use_sdpa(sdpa)
+
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
         # 1. Input
         batch, _, height, weight = hidden_states.shape
@@ -901,6 +938,10 @@ class CrossAttnDownBlock2D(nn.Module):
         for attn in self.attentions:
             attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
+    def set_use_sdpa(self, sdpa):
+        for attn in self.attentions:
+            attn.set_use_sdpa(sdpa)
+
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()
 
@@ -978,6 +1019,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         for attn in self.attentions:
             attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
+    def set_use_sdpa(self, sdpa):
+        for attn in self.attentions:
+            attn.set_use_sdpa(sdpa)
+
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         for i, resnet in enumerate(self.resnets):
             attn = None if i == 0 else self.attentions[i - 1]
@@ -1079,6 +1124,9 @@ class UpBlock2D(nn.Module):
     def set_use_memory_efficient_attention(self, xformers, mem_eff):
         pass
 
+    def set_use_sdpa(self, sdpa):
+        pass
+
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
         for resnet in self.resnets:
             # pop res hidden states
@@ -1159,6 +1207,10 @@ class CrossAttnUpBlock2D(nn.Module):
         for attn in self.attentions:
             attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
+    def set_use_sdpa(self, sdpa):
+        for attn in self.attentions:
+            attn.set_use_sdpa(sdpa)
+
     def forward(
         self,
         hidden_states,
@@ -1393,10 +1445,15 @@ class UNet2DConditionModel(nn.Module):
     def disable_gradient_checkpointing(self):
         self.set_gradient_checkpointing(value=False)
 
-    def set_use_memory_efficient_attention(self, xformers: bool,mem_eff:bool) -> None:
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
         modules = self.down_blocks + [self.mid_block] + self.up_blocks
         for module in modules:
-            module.set_use_memory_efficient_attention(xformers,mem_eff)
+            module.set_use_memory_efficient_attention(xformers, mem_eff)
 
+    def set_use_sdpa(self, sdpa: bool) -> None:
+        modules = self.down_blocks + [self.mid_block] + self.up_blocks
+        for module in modules:
+            module.set_use_sdpa(sdpa)
+
     def set_gradient_checkpointing(self, value=False):
         modules = self.down_blocks + [self.mid_block] + self.up_blocks
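For illustration only (not part of the commit): once these setters exist, SDPA can be switched on for an already-constructed model; `unet` below is a hypothetical instance of this UNet2DConditionModel. Because forward() checks the xformers and mem_eff flags before the sdpa flag, those should stay off when SDPA is wanted.

    # hypothetical usage sketch
    unet.set_use_memory_efficient_attention(False, False)  # keep xformers / mem_eff disabled
    unet.set_use_sdpa(True)  # every CrossAttention now dispatches to forward_sdpa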
@@ -1788,7 +1788,7 @@ class FlashAttentionFunction(torch.autograd.function.Function):
         return dq, dk, dv, None, None, None, None
 
 
-def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers):
+def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
     if mem_eff_attn:
         print("Enable memory efficient attention for U-Net")
         unet.set_use_memory_efficient_attention(False, True)
@@ -1800,6 +1800,9 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers):
             raise ImportError("No xformers / xformersがインストールされていないようです")
 
         unet.set_use_memory_efficient_attention(True, False)
+    elif sdpa:
+        print("Enable SDPA for U-Net")
+        unet.set_use_sdpa(True)
 
 """
 def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
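A hedged usage sketch (not part of the commit; the args attribute names are assumptions, with args.sdpa following from the --sdpa flag added in the next hunk) of how a training script would wire the new flag through:

    # replace_unet_modules picks exactly one attention backend for the U-Net
    replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
    # with --sdpa this reduces to unet.set_use_sdpa(True), so every
    # CrossAttention in the U-Net dispatches to forward_sdpa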
@@ -2048,6 +2051,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
     )
     parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
+    parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)")
     parser.add_argument(
         "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
    )
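The new flag's help text notes that SDPA needs PyTorch 2.0; a small standalone check (illustrative only, not part of the commit) before enabling it:

    import torch
    import torch.nn.functional as F

    if hasattr(F, "scaled_dot_product_attention"):
        print(f"PyTorch {torch.__version__}: SDPA available, --sdpa can be used")
    else:
        print(f"PyTorch {torch.__version__}: SDPA not available, fall back to xformers or memory efficient attention")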