support memory efficient attn (not xformers)

ykume
2023-06-11 16:54:41 +09:00
parent cc274fb7fb
commit 7f6b581ef8


@@ -131,6 +131,187 @@ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
 UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
+# region memory efficient attention
+
+# CrossAttention that uses FlashAttention
+# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
+# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
+
+# constants
+
+EPSILON = 1e-6
+
+# helper functions
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    return val if exists(val) else d
+
+
+# flash attention forwards and backwards
+
+# https://arxiv.org/abs/2205.14135
+
+
+class FlashAttentionFunction(torch.autograd.Function):
+    @staticmethod
+    @torch.no_grad()
+    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+        """Algorithm 2 in the paper"""
+
+        device = q.device
+        dtype = q.dtype
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        o = torch.zeros_like(q)
+        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
+        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
+
+        scale = q.shape[-1] ** -0.5
+
+        if not exists(mask):
+            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+        else:
+            mask = rearrange(mask, "b n -> b 1 1 n")
+            mask = mask.split(q_bucket_size, dim=-1)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            mask,
+            all_row_sums.split(q_bucket_size, dim=-2),
+            all_row_maxes.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+
+                if exists(row_mask):
+                    attn_weights.masked_fill_(~row_mask, max_neg_value)
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
+                        q_start_index - k_start_index + 1
+                    )
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
+                attn_weights -= block_row_maxes
+                exp_weights = torch.exp(attn_weights)
+
+                if exists(row_mask):
+                    exp_weights.masked_fill_(~row_mask, 0.0)
+
+                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
+
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+                exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
+
+                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
+                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
+
+                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
+
+                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+
+                row_maxes.copy_(new_row_maxes)
+                row_sums.copy_(new_row_sums)
+
+        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
+        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
+
+        return o
+    @staticmethod
+    @torch.no_grad()
+    def backward(ctx, do):
+        """Algorithm 4 in the paper"""
+
+        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
+        q, k, v, o, l, m = ctx.saved_tensors
+
+        device = q.device
+
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        dq = torch.zeros_like(q)
+        dk = torch.zeros_like(k)
+        dv = torch.zeros_like(v)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            do.split(q_bucket_size, dim=-2),
+            mask,
+            l.split(q_bucket_size, dim=-2),
+            m.split(q_bucket_size, dim=-2),
+            dq.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+                dk.split(k_bucket_size, dim=-2),
+                dv.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
+                        q_start_index - k_start_index + 1
+                    )
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                exp_attn_weights = torch.exp(attn_weights - mc)
+
+                if exists(row_mask):
+                    exp_attn_weights.masked_fill_(~row_mask, 0.0)
+
+                p = exp_attn_weights / lc
+
+                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
+                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
+
+                D = (doc * oc).sum(dim=-1, keepdims=True)
+                ds = p * scale * (dp - D)
+
+                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
+                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
+
+                dqc.add_(dq_chunk)
+                dkc.add_(dk_chunk)
+                dvc.add_(dv_chunk)
+
+        return dq, dk, dv, None, None, None, None
+
+
+# endregion
 def get_parameter_dtype(parameter: torch.nn.Module):
     return next(parameter.parameters()).dtype
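
As a quick, illustrative sanity check (not part of the commit; the shapes, bucket sizes, and tolerances below are arbitrary, and the module's own imports of torch, math, and einops are assumed), the chunked FlashAttentionFunction should agree with plain softmax attention in both the forward output and the hand-written backward:

import torch

b, h, n, d = 2, 4, 128, 32
q, k, v = (torch.randn(b, h, n, d, dtype=torch.float64, requires_grad=True) for _ in range(3))

# chunked path: no mask, non-causal, 64-row / 64-column buckets
out_flash = FlashAttentionFunction.apply(q, k, v, None, False, 64, 64)

# reference path: plain softmax attention
scale = d ** -0.5
attn = torch.softmax(torch.einsum("bhid,bhjd->bhij", q, k) * scale, dim=-1)
out_ref = torch.einsum("bhij,bhjd->bhid", attn, v)
assert torch.allclose(out_flash, out_ref, atol=1e-5)

# the custom backward (Algorithm 4) should match autograd through the reference
g = torch.randn_like(out_ref)
grads_flash = torch.autograd.grad(out_flash, (q, k, v), g)
grads_ref = torch.autograd.grad(out_ref, (q, k, v), g)
assert all(torch.allclose(a, b, atol=1e-5) for a, b in zip(grads_flash, grads_ref))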
@@ -310,7 +491,7 @@ class DownBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         pass
 
     def forward(self, hidden_states, temb=None):
@@ -382,9 +563,11 @@ class CrossAttention(nn.Module):
         # no dropout here
 
         self.use_memory_efficient_attention_xformers = False
+        self.use_memory_efficient_attention_mem_eff = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
-        self.use_memory_efficient_attention_xformers = value
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        self.use_memory_efficient_attention_xformers = xformers
+        self.use_memory_efficient_attention_mem_eff = mem_eff
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -403,6 +586,8 @@ class CrossAttention(nn.Module):
     def forward(self, hidden_states, context=None, mask=None):
         if self.use_memory_efficient_attention_xformers:
             return self.forward_memory_efficient_xformers(hidden_states, context, mask)
+        if self.use_memory_efficient_attention_mem_eff:
+            return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
@@ -468,6 +653,29 @@ class CrossAttention(nn.Module):
         out = self.to_out[0](out)
         return out
 
+    def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
+        flash_func = FlashAttentionFunction
+
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        h = self.heads
+        q = self.to_q(x)
+
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k = self.to_k(context)
+        v = self.to_v(context)
+        del context, x
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        out = self.to_out[0](out)
+        return out
 
 
 # feedforward
 class GEGLU(nn.Module):
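
The mem_eff path above keeps the head dimension as its own axis (b h n d) for FlashAttentionFunction, instead of folding heads into the batch axis the way reshape_heads_to_batch_dim does. A standalone sketch of that einops round trip (the shape numbers are illustrative, not from the commit):

import torch
from einops import rearrange

b, n, h, d = 2, 77, 8, 40                        # batch, tokens, heads, head dim
x = torch.randn(b, n, h * d)                     # e.g. the output of to_q

t = rearrange(x, "b n (h d) -> b h n d", h=h)    # split heads out for attention
assert t.shape == (b, h, n, d)

y = rearrange(t, "b h n d -> b n (h d)")         # merge heads back afterwards
assert torch.equal(x, y)                         # pure reshape/permute, lossless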
@@ -547,9 +755,9 @@ class BasicTransformerBlock(nn.Module):
         # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim)
 
-    def set_use_memory_efficient_attention_xformers(self, value: bool):
-        self.attn1.set_use_memory_efficient_attention_xformers(value)
-        self.attn2.set_use_memory_efficient_attention_xformers(value)
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
+        self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
+        self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(self, hidden_states, context=None, timestep=None):
         # 1. Self-Attention
@@ -608,9 +816,9 @@ class Transformer2DModel(nn.Module):
         else:
             self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         for transformer in self.transformer_blocks:
-            transformer.set_use_memory_efficient_attention_xformers(value)
+            transformer.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
         # 1. Input
@@ -689,9 +897,9 @@ class CrossAttnDownBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         for attn in self.attentions:
-            attn.set_use_memory_efficient_attention_xformers(value)
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()
@@ -766,9 +974,9 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         for attn in self.attentions:
-            attn.set_use_memory_efficient_attention_xformers(value)
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         for i, resnet in enumerate(self.resnets):
@@ -868,7 +1076,7 @@ class UpBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         pass
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
@@ -947,9 +1155,9 @@ class CrossAttnUpBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         for attn in self.attentions:
-            attn.set_use_memory_efficient_attention_xformers(value)
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(
         self,
@@ -1185,10 +1393,10 @@ class UNet2DConditionModel(nn.Module):
     def disable_gradient_checkpointing(self):
         self.set_gradient_checkpointing(value=False)
 
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
         modules = self.down_blocks + [self.mid_block] + self.up_blocks
         for module in modules:
-            module.set_use_memory_efficient_attention_xformers(valid)
+            module.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def set_gradient_checkpointing(self, value=False):
         modules = self.down_blocks + [self.mid_block] + self.up_blocks
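
For context, a minimal, self-contained sketch of how the two flags fan out from the top-level setter (MockAttnBlock and MockUNet are illustrative stand-ins, not the real classes; on a real model the call would be e.g. unet.set_use_memory_efficient_attention(False, True) to select the PyTorch path):

import torch.nn as nn

class MockAttnBlock(nn.Module):
    """Stand-in for the attention-bearing blocks; just records the flags."""

    def __init__(self):
        super().__init__()
        self.use_xformers = False
        self.use_mem_eff = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        self.use_xformers = xformers
        self.use_mem_eff = mem_eff

class MockUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.down_blocks = nn.ModuleList(MockAttnBlock() for _ in range(4))
        self.mid_block = MockAttnBlock()
        self.up_blocks = nn.ModuleList(MockAttnBlock() for _ in range(4))

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        # same fan-out as UNet2DConditionModel above
        modules = list(self.down_blocks) + [self.mid_block] + list(self.up_blocks)
        for module in modules:
            module.set_use_memory_efficient_attention(xformers, mem_eff)

unet = MockUNet()
unet.set_use_memory_efficient_attention(False, True)  # memory-efficient path, no xformers
assert all(m.use_mem_eff and not m.use_xformers for m in unet.modules() if isinstance(m, MockAttnBlock))

Note that CrossAttention.forward checks the xformers flag before the mem_eff flag, so if both are set to True the xformers path wins; callers should enable at most one.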