From c0a7df9ee14c715134f0e1cfce0a6256d1b64014 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Jun 2023 21:29:27 +0900 Subject: [PATCH] fix eps value, enable xformers, etc. --- gen_img_diffusers.py | 63 ++++++++++---------- library/original_unet.py | 120 +++++++++++++++++++++++++++++++++------ library/train_util.py | 66 +++++++++++---------- 3 files changed, 171 insertions(+), 78 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 4d8121ca..33b7a65c 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -317,7 +317,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio if mem_eff_attn: replace_unet_cross_attn_to_memory_efficient() elif xformers: - replace_unet_cross_attn_to_xformers() + replace_unet_cross_attn_to_xformers(unet) def replace_unet_cross_attn_to_memory_efficient(): @@ -357,50 +357,55 @@ def replace_unet_cross_attn_to_memory_efficient(): out = self.to_out[1](out) return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + # diffusers.models.attention.CrossAttention.forward = forward_flash_attn + # TODO U-Net側に移す + from library.original_unet import CrossAttention + CrossAttention.forward = forward_flash_attn -def replace_unet_cross_attn_to_xformers(): +def replace_unet_cross_attn_to_xformers(unet:UNet2DConditionModel): print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork") try: import xformers.ops except ImportError: raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention_xformers(True) - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + # def forward_xformers(self, x, context=None, mask=None): + # h = self.heads + # q_in = self.to_q(x) - context = default(context, x) - context = context.to(x.dtype) + # context = default(context, x) + # context = context.to(x.dtype) - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + # if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + # context_k, context_v = self.hypernetwork.forward(x, context) + # context_k = context_k.to(x.dtype) + # context_v = context_v.to(x.dtype) + # else: + # context_k = context + # context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + # k_in = self.to_k(context_k) + # v_in = self.to_v(context_v) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + # del q_in, k_in, v_in - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + # q = q.contiguous() + # k = k.contiguous() + # v = v.contiguous() + # out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - out = rearrange(out, "b n h d -> b n (h d)", h=h) + # out = rearrange(out, "b n h d -> b n (h d)", h=h) - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # # diffusers 0.7.0~ + # out = self.to_out[0](out) + # out = self.to_out[1](out) + # return out - diffusers.models.attention.CrossAttention.forward = forward_xformers + # diffusers.models.attention.CrossAttention.forward = 
forward_xformers def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): diff --git a/library/original_unet.py b/library/original_unet.py index 603239d5..47b751c1 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -1,10 +1,10 @@ # Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる +# 条件分岐等で不要な部分は削除している # コードの多くはDiffusersからコピーしている -# コードが冗長になる部分はコメント等を適宜削除する # 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある # Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers. -# Remove redundant code by deleting comments, etc. as appropriate +# Unnecessary parts are deleted by condition branching. # As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2 """ @@ -111,6 +111,7 @@ from typing import Dict, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F +from einops import rearrange BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] @@ -121,8 +122,8 @@ LAYERS_PER_BLOCK: int = 2 LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1 TIME_EMBED_FLIP_SIN_TO_COS: bool = True TIME_EMBED_FREQ_SHIFT: int = 0 -RESNET_GROUPS: int = 32 -RESNET_EPS: float = 1e-6 +NORM_GROUPS: int = 32 +NORM_EPS: float = 1e-5 TRANSFORMER_NORM_NUM_GROUPS = 32 DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] @@ -233,13 +234,13 @@ class ResnetBlock2D(nn.Module): self.in_channels = in_channels self.out_channels = out_channels - self.norm1 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=in_channels, eps=RESNET_EPS, affine=True) + self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels) - self.norm2 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=out_channels, eps=RESNET_EPS, affine=True) + self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) # if non_linearity == "swish": @@ -304,6 +305,9 @@ class DownBlock2D(nn.Module): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, value): + pass + def forward(self, hidden_states, temb=None): output_states = () @@ -372,6 +376,11 @@ class CrossAttention(nn.Module): self.to_out.append(nn.Linear(inner_dim, query_dim)) # no dropout here + self.use_memory_efficient_attention_xformers = False + + def set_use_memory_efficient_attention_xformers(self, value): + self.use_memory_efficient_attention_xformers = value + def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads @@ -387,6 +396,9 @@ class CrossAttention(nn.Module): return tensor def forward(self, hidden_states, context=None, mask=None): + if self.use_memory_efficient_attention_xformers: + return self.forward_memory_efficient_xformers(hidden_states, context, mask) + query = self.to_q(hidden_states) context = context if context is not None else hidden_states key = self.to_k(context) @@ -427,6 +439,30 @@ class CrossAttention(nn.Module): hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states + # TODO support Hypernetworks + def forward_memory_efficient_xformers(self, x, 
context=None, mask=None): + import xformers.ops + + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + # feedforward class GEGLU(nn.Module): @@ -506,8 +542,9 @@ class BasicTransformerBlock(nn.Module): # 3. Feed-forward self.norm3 = nn.LayerNorm(dim) - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - raise NotImplementedError("Memory efficient attention is not implemented for this model.") + def set_use_memory_efficient_attention_xformers(self, value: bool): + self.attn1.set_use_memory_efficient_attention_xformers(value) + self.attn2.set_use_memory_efficient_attention_xformers(value) def forward(self, hidden_states, context=None, timestep=None): # 1. Self-Attention @@ -566,6 +603,10 @@ class Transformer2DModel(nn.Module): else: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + def set_use_memory_efficient_attention_xformers(self, value): + for transformer in self.transformer_blocks: + transformer.set_use_memory_efficient_attention_xformers(value) + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): # 1. Input batch, _, height, weight = hidden_states.shape @@ -643,6 +684,10 @@ class CrossAttnDownBlock2D(nn.Module): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, value): + for attn in self.attentions: + attn.set_use_memory_efficient_attention_xformers(value) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -714,11 +759,37 @@ class UNetMidBlock2DCrossAttn(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention_xformers(self, value): + for attn in self.attentions: + attn.set_use_memory_efficient_attention_xformers(value) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states).sample - hidden_states = resnet(hidden_states, temb) + for i, resnet in enumerate(self.resnets): + attn = None if i == 0 else self.attentions[i - 1] + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + if attn is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -792,6 +863,9 @@ class 
UpBlock2D(nn.Module): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, value): + pass + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states @@ -868,6 +942,10 @@ class CrossAttnUpBlock2D(nn.Module): self.gradient_checkpointing = False + def set_use_memory_efficient_attention_xformers(self, value): + for attn in self.attentions: + attn.set_use_memory_efficient_attention_xformers(value) + def forward( self, hidden_states, @@ -991,6 +1069,8 @@ class UNet2DConditionModel(nn.Module): self.sample_size = sample_size + # state_dictの書式が変わるのでmoduleの持ち方は変えられない + # input self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1)) @@ -1069,7 +1149,7 @@ class UNet2DConditionModel(nn.Module): prev_output_channel = output_channel # out - self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=RESNET_GROUPS, eps=RESNET_EPS) + self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) @@ -1088,16 +1168,20 @@ class UNet2DConditionModel(nn.Module): return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) def enable_gradient_checkpointing(self): - self._set_gradient_checkpointing(self, value=True) + self.set_gradient_checkpointing(value=True) def disable_gradient_checkpointing(self): - self._set_gradient_checkpointing(self, value=False) + self.set_gradient_checkpointing(value=False) def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None: - raise NotImplementedError("Memory efficient attention is not supported for this model.") + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + module.set_use_memory_efficient_attention_xformers(valid) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + def set_gradient_checkpointing(self, value=False): + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + print(module.__class__.__name__, module.gradient_checkpointing, "->", value) module.gradient_checkpointing = value # endregion diff --git a/library/train_util.py b/library/train_util.py index 844faca7..b7cee937 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1792,7 +1792,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio if mem_eff_attn: replace_unet_cross_attn_to_memory_efficient() elif xformers: - replace_unet_cross_attn_to_xformers() + replace_unet_cross_attn_to_xformers(unet) def replace_unet_cross_attn_to_memory_efficient(): @@ -1827,55 +1827,59 @@ def replace_unet_cross_attn_to_memory_efficient(): out = rearrange(out, "b h n d -> b n (h d)") - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) out = self.to_out[0](out) - out = self.to_out[1](out) + # out = self.to_out[1](out) return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + # diffusers.models.attention.CrossAttention.forward = forward_flash_attn + from library.original_unet import CrossAttention + + CrossAttention.forward = forward_flash_attn -def replace_unet_cross_attn_to_xformers(): +def replace_unet_cross_attn_to_xformers(unet): print("CrossAttention.forward has been replaced to enable xformers.") try: 
import xformers.ops except ImportError: raise ImportError("No xformers / xformersがインストールされていないようです") - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + unet.set_use_memory_efficient_attention_xformers(True) - context = default(context, x) - context = context.to(x.dtype) + # def forward_xformers(self, x, context=None, mask=None): + # h = self.heads + # q_in = self.to_q(x) - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + # context = default(context, x) + # context = context.to(x.dtype) - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + # if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + # context_k, context_v = self.hypernetwork.forward(x, context) + # context_k = context_k.to(x.dtype) + # context_v = context_v.to(x.dtype) + # else: + # context_k = context + # context_v = context - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + # k_in = self.to_k(context_k) + # v_in = self.to_v(context_v) - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + # del q_in, k_in, v_in - out = rearrange(out, "b n h d -> b n (h d)", h=h) + # q = q.contiguous() + # k = k.contiguous() + # v = v.contiguous() + # out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # out = rearrange(out, "b n h d -> b n (h d)", h=h) - diffusers.models.attention.CrossAttention.forward = forward_xformers + # # diffusers 0.7.0~ + # out = self.to_out[0](out) + # out = self.to_out[1](out) + # return out + + # diffusers.models.attention.CrossAttention.forward = forward_xformers """
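Usage note (illustration only, not part of the patch): after this change, enabling xformers no longer monkey-patches diffusers.models.attention.CrossAttention.forward. Instead, replace_unet_cross_attn_to_xformers() receives the U-Net and calls set_use_memory_efficient_attention_xformers(True), which the custom U-Net in library/original_unet.py now propagates down through its blocks to every CrossAttention module. The patch also renames RESNET_GROUPS / RESNET_EPS to NORM_GROUPS / NORM_EPS and changes the GroupNorm eps from 1e-6 to 1e-5, in line with the default norm_eps used by diffusers' UNet2DConditionModel. The sketch below traces the new call path; it assumes `unet` is an already-loaded library.original_unet.UNet2DConditionModel and that xformers is installed.

# Minimal sketch of the new xformers path (assumes `unet` is already constructed
# elsewhere and xformers is installed; this mirrors the patched code, it is not
# itself part of the diff).
from library import train_util

train_util.replace_unet_modules(unet, mem_eff_attn=False, xformers=True)
# -> replace_unet_cross_attn_to_xformers(unet)
#    -> unet.set_use_memory_efficient_attention_xformers(True)
#       which walks down_blocks + [mid_block] + up_blocks and sets the flag on
#       each CrossAttention; CrossAttention.forward() then dispatches to
#       forward_memory_efficient_xformers(), which calls
#       xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)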