diff --git a/README.md b/README.md index 8234a89e..634b9494 100644 --- a/README.md +++ b/README.md @@ -75,8 +75,6 @@ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_set accelerate config ``` -update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python). - Answers to accelerate config: ```txt @@ -94,6 +92,30 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o (Single GPU with id `0` will be used.) +### Experimental: Use PyTorch 2.0 + +In this case, you need to install PyTorch 2.0 and xformers 0.0.20. Instead of the above, please type the following: + +```powershell +git clone https://github.com/kohya-ss/sd-scripts.git +cd sd-scripts + +python -m venv venv +.\venv\Scripts\activate + +pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install --upgrade -r requirements.txt +pip install xformers==0.0.20 + +cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ +cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py +cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py + +accelerate config +``` + +Answers to accelerate config should be the same as above. + ### about PyTorch and xformers Other versions of PyTorch and xformers seem to have problems with training. 
diff --git a/fine_tune.py b/fine_tune.py index 881845c5..120f3d0f 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -141,7 +141,7 @@ def train(args): # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある print("Disable Diffusers' xformers") set_diffusers_xformers_flag(unet, False) - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 71daa9a1..889b4c4c 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -137,7 +137,7 @@ USE_CUTOUTS = False """ -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: print("Enable memory efficient attention for U-Net") @@ -151,56 +151,26 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio raise ImportError("No xformers / xformersがインストールされていないようです") unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) # TODO common train_util.py -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): if mem_eff_attn: replace_vae_attn_to_memory_efficient() elif xformers: - # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ replace_vae_attn_to_xformers() + elif sdpa: + replace_vae_attn_to_sdpa() def replace_vae_attn_to_memory_efficient(): print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction - def forward_flash_attn_0_14(self, hidden_states, **kwargs): - q_bucket_size = 512 - k_bucket_size = 1024 - - 
residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - def forward_flash_attn(self, hidden_states, **kwargs): q_bucket_size = 512 k_bucket_size = 1024 @@ -238,6 +208,15 @@ def replace_vae_attn_to_memory_efficient(): hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + if diffusers.__version__ < "0.15.0": diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 else: @@ -248,40 +227,6 @@ def replace_vae_attn_to_xformers(): print("VAE: Attention.forward has been replaced to xformers") import xformers.ops - def forward_xformers_0_14(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = 
self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) - ) - - query_proj = query_proj.contiguous() - key_proj = key_proj.contiguous() - value_proj = value_proj.contiguous() - out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - def forward_xformers(self, hidden_states, **kwargs): residual = hidden_states batch, channel, height, width = hidden_states.shape @@ -319,12 +264,75 @@ def replace_vae_attn_to_xformers(): hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + if diffusers.__version__ < "0.15.0": diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 else: diffusers.models.attention_processor.Attention.forward = forward_xformers +def replace_vae_attn_to_sdpa(): + print("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + 
hidden_states = self.group_norm(hidden_states) + +    hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + +    # proj to q, k, v +    query_proj = self.to_q(hidden_states) +    key_proj = self.to_k(hidden_states) +    value_proj = self.to_v(hidden_states) + +    query_proj, key_proj, value_proj = map( +        lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) +    ) + +    out = torch.nn.functional.scaled_dot_product_attention( +        query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False +    ) + +    out = rearrange(out, "b h n d -> b n (h d)") + +    # compute next hidden_states +    # linear proj +    hidden_states = self.to_out[0](out) +    # dropout +    hidden_states = self.to_out[1](hidden_states) + +    hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + +    # res connect and rescale +    hidden_states = (hidden_states + residual) / self.rescale_output_factor +    return hidden_states + +    def forward_sdpa_0_14(self, hidden_states, **kwargs): +        if not hasattr(self, "to_q"): +            self.to_q = self.query +            self.to_k = self.key +            self.to_v = self.value +            self.to_out = [self.proj_attn, torch.nn.Identity()] +            self.heads = self.num_heads +        return forward_sdpa(self, hidden_states, **kwargs) + +    if diffusers.__version__ < "0.15.0": +        diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 +    else: +        diffusers.models.attention_processor.Attention.forward = forward_sdpa + + # endregion # region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 @@ -2082,8 +2090,9 @@ def main(args): # xformers、Hypernetwork対応 if not args.diffusers_xformers: - replace_unet_modules(unet, not args.xformers, args.xformers) - replace_vae_modules(vae, not args.xformers, args.xformers) + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む print("loading tokenizer") @@ 
-3176,6 +3185,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") parser.add_argument( "--diffusers_xformers", action="store_true", diff --git a/library/original_unet.py b/library/original_unet.py index 36318eb9..e22b16c0 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -494,6 +494,9 @@ class DownBlock2D(nn.Module): def set_use_memory_efficient_attention(self, xformers, mem_eff): pass + def set_use_sdpa(self, sdpa): + pass + def forward(self, hidden_states, temb=None): output_states = () @@ -564,11 +567,15 @@ class CrossAttention(nn.Module): self.use_memory_efficient_attention_xformers = False self.use_memory_efficient_attention_mem_eff = False + self.use_sdpa = False def set_use_memory_efficient_attention(self, xformers, mem_eff): self.use_memory_efficient_attention_xformers = xformers self.use_memory_efficient_attention_mem_eff = mem_eff + def set_use_sdpa(self, sdpa): + self.use_sdpa = sdpa + def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads @@ -588,6 +595,8 @@ class CrossAttention(nn.Module): return self.forward_memory_efficient_xformers(hidden_states, context, mask) if self.use_memory_efficient_attention_mem_eff: return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) + if self.use_sdpa: + return self.forward_sdpa(hidden_states, context, mask) query = self.to_q(hidden_states) context = context if context is not None else hidden_states @@ -676,6 +685,26 @@ class CrossAttention(nn.Module): out = self.to_out[0](out) return out + def forward_sdpa(self, x, context=None, mask=None): + # NOTE: sdpa path needs no xformers; uses torch.nn.functional only + + 
h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + # feedforward class GEGLU(nn.Module): @@ -759,6 +788,10 @@ class BasicTransformerBlock(nn.Module): self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, sdpa: bool): + self.attn1.set_use_sdpa(sdpa) + self.attn2.set_use_sdpa(sdpa) + def forward(self, hidden_states, context=None, timestep=None): # 1. Self-Attention norm_hidden_states = self.norm1(hidden_states) @@ -820,6 +853,10 @@ class Transformer2DModel(nn.Module): for transformer in self.transformer_blocks: transformer.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, sdpa): + for transformer in self.transformer_blocks: + transformer.set_use_sdpa(sdpa) + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): # 1. 
Input batch, _, height, weight = hidden_states.shape @@ -901,6 +938,10 @@ class CrossAttnDownBlock2D(nn.Module): for attn in self.attentions: attn.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -978,6 +1019,10 @@ class UNetMidBlock2DCrossAttn(nn.Module): for attn in self.attentions: attn.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): for i, resnet in enumerate(self.resnets): attn = None if i == 0 else self.attentions[i - 1] @@ -1079,6 +1124,9 @@ class UpBlock2D(nn.Module): def set_use_memory_efficient_attention(self, xformers, mem_eff): pass + def set_use_sdpa(self, sdpa): + pass + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states @@ -1159,6 +1207,10 @@ class CrossAttnUpBlock2D(nn.Module): for attn in self.attentions: attn.set_use_memory_efficient_attention(xformers, mem_eff) + def set_use_sdpa(self, spda): + for attn in self.attentions: + attn.set_use_sdpa(spda) + def forward( self, hidden_states, @@ -1393,10 +1445,15 @@ class UNet2DConditionModel(nn.Module): def disable_gradient_checkpointing(self): self.set_gradient_checkpointing(value=False) - def set_use_memory_efficient_attention(self, xformers: bool,mem_eff:bool) -> None: + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: modules = self.down_blocks + [self.mid_block] + self.up_blocks for module in modules: - module.set_use_memory_efficient_attention(xformers,mem_eff) + module.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool) -> None: + modules = self.down_blocks + [self.mid_block] + 
self.up_blocks + for module in modules: + module.set_use_sdpa(sdpa) def set_gradient_checkpointing(self, value=False): modules = self.down_blocks + [self.mid_block] + self.up_blocks diff --git a/library/train_util.py b/library/train_util.py index 3ae5d0f3..30380262 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1788,7 +1788,7 @@ class FlashAttentionFunction(torch.autograd.function.Function): return dq, dk, dv, None, None, None, None -def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers): +def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: print("Enable memory efficient attention for U-Net") unet.set_use_memory_efficient_attention(False, True) @@ -1800,6 +1800,9 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers): raise ImportError("No xformers / xformersがインストールされていないようです") unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_sdpa(True) """ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): @@ -2048,6 +2051,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)") parser.add_argument( "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" ) diff --git a/train_db.py b/train_db.py index 09f8d361..895b8b24 100644 --- a/train_db.py +++ b/train_db.py @@ -119,7 +119,7 @@ def train(args): use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) # モデルに xformers とか memory 
efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: diff --git a/train_network.py b/train_network.py index 7c74ae5d..9ea9bf9c 100644 --- a/train_network.py +++ b/train_network.py @@ -160,7 +160,7 @@ def train(args): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 差分追加学習のためにモデルを読み込む import sys diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9dd846bd..4f31220d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -231,7 +231,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 1ea6dfc6..69f618cc 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -264,7 +264,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) original_unet.UNet2DConditionModel.forward = unet_forward_XTI original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI original_unet.CrossAttnUpBlock2D.forward = 
upblock_forward_XTI