mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support memory efficient attn (not xformers)
This commit is contained in:
@@ -131,6 +131,187 @@ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDo
|
||||
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
|
||||
|
||||
|
||||
# region memory effcient attention
|
||||
|
||||
# FlashAttentionを使うCrossAttention
|
||||
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
||||
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
||||
|
||||
# constants
|
||||
|
||||
EPSILON = 1e-6
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
# flash attention forwards and backwards
|
||||
|
||||
# https://arxiv.org/abs/2205.14135
|
||||
|
||||
|
||||
class FlashAttentionFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
||||
"""Algorithm 2 in the paper"""
|
||||
|
||||
device = q.device
|
||||
dtype = q.dtype
|
||||
max_neg_value = -torch.finfo(q.dtype).max
|
||||
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
||||
|
||||
o = torch.zeros_like(q)
|
||||
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
||||
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
||||
|
||||
scale = q.shape[-1] ** -0.5
|
||||
|
||||
if not exists(mask):
|
||||
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
||||
else:
|
||||
mask = rearrange(mask, "b n -> b 1 1 n")
|
||||
mask = mask.split(q_bucket_size, dim=-1)
|
||||
|
||||
row_splits = zip(
|
||||
q.split(q_bucket_size, dim=-2),
|
||||
o.split(q_bucket_size, dim=-2),
|
||||
mask,
|
||||
all_row_sums.split(q_bucket_size, dim=-2),
|
||||
all_row_maxes.split(q_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
||||
q_start_index = ind * q_bucket_size - qk_len_diff
|
||||
|
||||
col_splits = zip(
|
||||
k.split(k_bucket_size, dim=-2),
|
||||
v.split(k_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for k_ind, (kc, vc) in enumerate(col_splits):
|
||||
k_start_index = k_ind * k_bucket_size
|
||||
|
||||
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
||||
|
||||
if exists(row_mask):
|
||||
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
||||
|
||||
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
||||
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
||||
q_start_index - k_start_index + 1
|
||||
)
|
||||
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
||||
|
||||
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
||||
attn_weights -= block_row_maxes
|
||||
exp_weights = torch.exp(attn_weights)
|
||||
|
||||
if exists(row_mask):
|
||||
exp_weights.masked_fill_(~row_mask, 0.0)
|
||||
|
||||
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
||||
|
||||
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
||||
|
||||
exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
|
||||
|
||||
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
||||
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
||||
|
||||
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
||||
|
||||
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
||||
|
||||
row_maxes.copy_(new_row_maxes)
|
||||
row_sums.copy_(new_row_sums)
|
||||
|
||||
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
||||
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
||||
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def backward(ctx, do):
|
||||
"""Algorithm 4 in the paper"""
|
||||
|
||||
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
||||
q, k, v, o, l, m = ctx.saved_tensors
|
||||
|
||||
device = q.device
|
||||
|
||||
max_neg_value = -torch.finfo(q.dtype).max
|
||||
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
||||
|
||||
dq = torch.zeros_like(q)
|
||||
dk = torch.zeros_like(k)
|
||||
dv = torch.zeros_like(v)
|
||||
|
||||
row_splits = zip(
|
||||
q.split(q_bucket_size, dim=-2),
|
||||
o.split(q_bucket_size, dim=-2),
|
||||
do.split(q_bucket_size, dim=-2),
|
||||
mask,
|
||||
l.split(q_bucket_size, dim=-2),
|
||||
m.split(q_bucket_size, dim=-2),
|
||||
dq.split(q_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
||||
q_start_index = ind * q_bucket_size - qk_len_diff
|
||||
|
||||
col_splits = zip(
|
||||
k.split(k_bucket_size, dim=-2),
|
||||
v.split(k_bucket_size, dim=-2),
|
||||
dk.split(k_bucket_size, dim=-2),
|
||||
dv.split(k_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
||||
k_start_index = k_ind * k_bucket_size
|
||||
|
||||
attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
||||
|
||||
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
||||
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
||||
q_start_index - k_start_index + 1
|
||||
)
|
||||
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
||||
|
||||
exp_attn_weights = torch.exp(attn_weights - mc)
|
||||
|
||||
if exists(row_mask):
|
||||
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
||||
|
||||
p = exp_attn_weights / lc
|
||||
|
||||
dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
|
||||
dp = einsum("... i d, ... j d -> ... i j", doc, vc)
|
||||
|
||||
D = (doc * oc).sum(dim=-1, keepdims=True)
|
||||
ds = p * scale * (dp - D)
|
||||
|
||||
dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
|
||||
dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
|
||||
|
||||
dqc.add_(dq_chunk)
|
||||
dkc.add_(dk_chunk)
|
||||
dvc.add_(dv_chunk)
|
||||
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def get_parameter_dtype(parameter: torch.nn.Module):
|
||||
return next(parameter.parameters()).dtype
|
||||
|
||||
@@ -310,7 +491,7 @@ class DownBlock2D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value):
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
pass
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
@@ -382,9 +563,11 @@ class CrossAttention(nn.Module):
|
||||
# no dropout here
|
||||
|
||||
self.use_memory_efficient_attention_xformers = False
|
||||
self.use_memory_efficient_attention_mem_eff = False
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value):
|
||||
self.use_memory_efficient_attention_xformers = value
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
self.use_memory_efficient_attention_xformers = xformers
|
||||
self.use_memory_efficient_attention_mem_eff = mem_eff
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
@@ -403,6 +586,8 @@ class CrossAttention(nn.Module):
|
||||
def forward(self, hidden_states, context=None, mask=None):
|
||||
if self.use_memory_efficient_attention_xformers:
|
||||
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
||||
if self.use_memory_efficient_attention_mem_eff:
|
||||
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
context = context if context is not None else hidden_states
|
||||
@@ -468,6 +653,29 @@ class CrossAttention(nn.Module):
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
|
||||
def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
q_bucket_size = 512
|
||||
k_bucket_size = 1024
|
||||
|
||||
h = self.heads
|
||||
q = self.to_q(x)
|
||||
context = context if context is not None else x
|
||||
context = context.to(x.dtype)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
@@ -547,9 +755,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value: bool):
|
||||
self.attn1.set_use_memory_efficient_attention_xformers(value)
|
||||
self.attn2.set_use_memory_efficient_attention_xformers(value)
|
||||
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
|
||||
self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
|
||||
def forward(self, hidden_states, context=None, timestep=None):
|
||||
# 1. Self-Attention
|
||||
@@ -608,9 +816,9 @@ class Transformer2DModel(nn.Module):
|
||||
else:
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value):
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
for transformer in self.transformer_blocks:
|
||||
transformer.set_use_memory_efficient_attention_xformers(value)
|
||||
transformer.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
||||
# 1. Input
|
||||
@@ -689,9 +897,9 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value):
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
for attn in self.attentions:
|
||||
attn.set_use_memory_efficient_attention_xformers(value)
|
||||
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = ()
|
||||
@@ -766,9 +974,9 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value):
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
for attn in self.attentions:
|
||||
attn.set_use_memory_efficient_attention_xformers(value)
|
||||
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
@@ -868,7 +1076,7 @@ class UpBlock2D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value):
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
pass
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
@@ -947,9 +1155,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value):
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
for attn in self.attentions:
|
||||
attn.set_use_memory_efficient_attention_xformers(value)
|
||||
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1185,10 +1393,10 @@ class UNet2DConditionModel(nn.Module):
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.set_gradient_checkpointing(value=False)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
|
||||
def set_use_memory_efficient_attention(self, xformers: bool,mem_eff:bool) -> None:
|
||||
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
||||
for module in modules:
|
||||
module.set_use_memory_efficient_attention_xformers(valid)
|
||||
module.set_use_memory_efficient_attention(xformers,mem_eff)
|
||||
|
||||
def set_gradient_checkpointing(self, value=False):
|
||||
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
||||
|
||||
Reference in New Issue
Block a user