diff --git a/library/original_unet.py b/library/original_unet.py
index e8920727..0e64280b 100644
--- a/library/original_unet.py
+++ b/library/original_unet.py
@@ -131,6 +131,187 @@ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDo
 UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
 
+# region memory efficient attention
+
+# CrossAttention that uses FlashAttention
+# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
+# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
+
+# constants
+
+EPSILON = 1e-6
+
+# helper functions
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    return val if exists(val) else d
+
+
+# flash attention forwards and backwards
+
+# https://arxiv.org/abs/2205.14135
+
+
+class FlashAttentionFunction(torch.autograd.Function):
+    @staticmethod
+    @torch.no_grad()
+    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+        """Algorithm 2 in the paper"""
+
+        device = q.device
+        dtype = q.dtype
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        o = torch.zeros_like(q)
+        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
+        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
+
+        scale = q.shape[-1] ** -0.5
+
+        if not exists(mask):
+            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+        else:
+            mask = rearrange(mask, "b n -> b 1 1 n")
+            mask = mask.split(q_bucket_size, dim=-1)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            mask,
+            all_row_sums.split(q_bucket_size, dim=-2),
+            all_row_maxes.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+
+                if exists(row_mask):
+                    attn_weights.masked_fill_(~row_mask, max_neg_value)
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
+                        q_start_index - k_start_index + 1
+                    )
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
+                attn_weights -= block_row_maxes
+                exp_weights = torch.exp(attn_weights)
+
+                if exists(row_mask):
+                    exp_weights.masked_fill_(~row_mask, 0.0)
+
+                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
+
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+                exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
+
+                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
+                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
+
+                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
+
+                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+
+                row_maxes.copy_(new_row_maxes)
+                row_sums.copy_(new_row_sums)
+
+        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
+        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
+
+        return o
+
+    @staticmethod
+    @torch.no_grad()
+    def backward(ctx, do):
+        """Algorithm 4 in the paper"""
+
+        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
+        q, k, v, o, l, m = ctx.saved_tensors
+
+        device = q.device
+
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        dq = torch.zeros_like(q)
+        dk = torch.zeros_like(k)
+        dv = torch.zeros_like(v)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim=-2),
+            o.split(q_bucket_size, dim=-2),
+            do.split(q_bucket_size, dim=-2),
+            mask,
+            l.split(q_bucket_size, dim=-2),
+            m.split(q_bucket_size, dim=-2),
+            dq.split(q_bucket_size, dim=-2),
+        )
+
+        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim=-2),
+                v.split(k_bucket_size, dim=-2),
+                dk.split(k_bucket_size, dim=-2),
+                dv.split(k_bucket_size, dim=-2),
+            )
+
+            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
+                        q_start_index - k_start_index + 1
+                    )
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                exp_attn_weights = torch.exp(attn_weights - mc)
+
+                if exists(row_mask):
+                    exp_attn_weights.masked_fill_(~row_mask, 0.0)
+
+                p = exp_attn_weights / lc
+
+                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
+                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
+
+                D = (doc * oc).sum(dim=-1, keepdims=True)
+                ds = p * scale * (dp - D)
+
+                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
+                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
+
+                dqc.add_(dq_chunk)
+                dkc.add_(dk_chunk)
+                dvc.add_(dv_chunk)
+
+        return dq, dk, dv, None, None, None, None
+
+
+# endregion
+
+
 def get_parameter_dtype(parameter: torch.nn.Module):
     return next(parameter.parameters()).dtype
 
@@ -310,7 +491,7 @@ class DownBlock2D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         pass
 
     def forward(self, hidden_states, temb=None):
@@ -382,9 +563,11 @@ class CrossAttention(nn.Module):
         # no dropout here
 
         self.use_memory_efficient_attention_xformers = False
+        self.use_memory_efficient_attention_mem_eff = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
-        self.use_memory_efficient_attention_xformers = value
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
+        self.use_memory_efficient_attention_xformers = xformers
+        self.use_memory_efficient_attention_mem_eff = mem_eff
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -403,6 +586,8 @@ class CrossAttention(nn.Module):
     def forward(self, hidden_states, context=None, mask=None):
         if self.use_memory_efficient_attention_xformers:
             return self.forward_memory_efficient_xformers(hidden_states, context, mask)
+        if self.use_memory_efficient_attention_mem_eff:
+            return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
@@ -468,6 +653,29 @@ class CrossAttention(nn.Module):
         out = self.to_out[0](out)
         return out
 
+    def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
+        flash_func = FlashAttentionFunction
+
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        h = self.heads
+        q = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k = self.to_k(context)
+        v = self.to_v(context)
+        del context, x
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        out = self.to_out[0](out)
+        return out
+
 
 # feedforward
 class GEGLU(nn.Module):
@@ -547,9 +755,9 @@ class BasicTransformerBlock(nn.Module):
         # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim)
 
-    def set_use_memory_efficient_attention_xformers(self, value: bool):
-        self.attn1.set_use_memory_efficient_attention_xformers(value)
-        self.attn2.set_use_memory_efficient_attention_xformers(value)
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
+        self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
+        self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(self, hidden_states, context=None, timestep=None):
         # 1. Self-Attention
@@ -608,9 +816,9 @@ class Transformer2DModel(nn.Module):
         else:
             self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         for transformer in self.transformer_blocks:
-            transformer.set_use_memory_efficient_attention_xformers(value)
+            transformer.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
         # 1. Input
@@ -689,9 +897,9 @@ class CrossAttnDownBlock2D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         for attn in self.attentions:
-            attn.set_use_memory_efficient_attention_xformers(value)
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()
@@ -766,9 +974,9 @@ class UNetMidBlock2DCrossAttn(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         for attn in self.attentions:
-            attn.set_use_memory_efficient_attention_xformers(value)
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         for i, resnet in enumerate(self.resnets):
@@ -868,7 +1076,7 @@ class UpBlock2D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         pass
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
@@ -947,9 +1155,9 @@ class CrossAttnUpBlock2D(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def set_use_memory_efficient_attention_xformers(self, value):
+    def set_use_memory_efficient_attention(self, xformers, mem_eff):
         for attn in self.attentions:
-            attn.set_use_memory_efficient_attention_xformers(value)
+            attn.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def forward(
         self,
@@ -1185,10 +1393,10 @@ class UNet2DConditionModel(nn.Module):
     def disable_gradient_checkpointing(self):
        self.set_gradient_checkpointing(value=False)
 
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
         modules = self.down_blocks + [self.mid_block] + self.up_blocks
         for module in modules:
-            module.set_use_memory_efficient_attention_xformers(valid)
+            module.set_use_memory_efficient_attention(xformers, mem_eff)
 
     def set_gradient_checkpointing(self, value=False):
         modules = self.down_blocks + [self.mid_block] + self.up_blocks
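Note (not part of the patch): with this change the per-module toggle becomes set_use_memory_efficient_attention(xformers, mem_eff), so the chunked FlashAttention path is enabled with unet.set_use_memory_efficient_attention(False, True) and the xformers path with unet.set_use_memory_efficient_attention(True, False). Below is a minimal sanity-check sketch for the added FlashAttentionFunction; it assumes the patched module is importable as library.original_unet, uses deliberately small bucket sizes so several query/key chunks are exercised, and compares the result against naive softmax attention.

    import torch
    from library.original_unet import FlashAttentionFunction

    # random attention inputs: (batch, heads, tokens, head_dim)
    b, h, n, d = 2, 4, 64, 32
    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

    # chunked flash attention: mask=None, causal=False, small bucket sizes to force several chunks
    out_flash = FlashAttentionFunction.apply(q, k, v, None, False, 16, 16)

    # reference: plain scaled dot-product softmax attention
    scale = d**-0.5
    attn = torch.softmax(torch.einsum("bhid,bhjd->bhij", q, k) * scale, dim=-1)
    out_ref = torch.einsum("bhij,bhjd->bhid", attn, v)

    print(torch.allclose(out_flash, out_ref, atol=1e-5))  # expected: True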