diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 99e94cae..27bd7460 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -311,6 +311,7 @@ class FlashAttentionFunction(torch.autograd.Function):
         return dq, dk, dv, None, None, None, None
 
 
+# TODO common train_util.py
 def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
     if mem_eff_attn:
         replace_unet_cross_attn_to_memory_efficient()
@@ -319,7 +320,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
 
 
 def replace_unet_cross_attn_to_memory_efficient():
-    print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention")
+    print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork")
     flash_func = FlashAttentionFunction
 
     def forward_flash_attn(self, x, context=None, mask=None):
@@ -359,7 +360,7 @@ def replace_unet_cross_attn_to_memory_efficient():
 
 
 def replace_unet_cross_attn_to_xformers():
-    print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers")
+    print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
     try:
         import xformers.ops
     except ImportError:
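Note: the replace_* helpers above, and the VAE variants added below, all rely on the same mechanism: a replacement forward that takes self explicitly is assigned to the Diffusers class, so every existing and future instance picks it up. A minimal standalone sketch of that pattern, using a hypothetical ToyAttention module rather than the real Diffusers classes:

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

def patched_forward(self, x):
    # `self` is passed explicitly, just like forward_flash_attn / forward_xformers above
    return self.proj(x) * 2.0

# assigning to the class swaps the method for every existing and future instance
ToyAttention.forward = patched_forward

m = ToyAttention(8)
print(m(torch.randn(2, 8)).shape)  # torch.Size([2, 8])

Per-instance patching (upsampler.forward = ..., as in the commented-out block below) works the same way, except that only the patched object is affected.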
@@ -401,6 +402,104 @@ def replace_unet_cross_attn_to_xformers():
     diffusers.models.attention.CrossAttention.forward = forward_xformers
 
 
+def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
+    if mem_eff_attn:
+        replace_vae_attn_to_memory_efficient()
+    elif xformers:
+        # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
+        print("Use Diffusers xformers for VAE")
+        vae.set_use_memory_efficient_attention_xformers(True)
+
+    """
+    # VAEがbfloat16でメモリ消費が大きい問題を解決する
+    upsamplers = []
+    for block in vae.decoder.up_blocks:
+        if block.upsamplers is not None:
+            upsamplers.extend(block.upsamplers)
+
+    def forward_upsample(_self, hidden_states, output_size=None):
+        assert hidden_states.shape[1] == _self.channels
+        if _self.use_conv_transpose:
+            return _self.conv(hidden_states)
+
+        dtype = hidden_states.dtype
+        if dtype == torch.bfloat16:
+            assert output_size is None
+            # repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する
+            hidden_states = hidden_states.repeat_interleave(2, dim=-1)
+            hidden_states = hidden_states.repeat_interleave(2, dim=-2)
+        else:
+            if hidden_states.shape[0] >= 64:
+                hidden_states = hidden_states.contiguous()
+
+            # if `output_size` is passed we force the interpolation output
+            # size and do not make use of `scale_factor=2`
+            if output_size is None:
+                hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+            else:
+                hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest")
+
+        if _self.use_conv:
+            if _self.name == "conv":
+                hidden_states = _self.conv(hidden_states)
+            else:
+                hidden_states = _self.Conv2d_0(hidden_states)
+        return hidden_states
+
+    # replace upsamplers
+    for upsampler in upsamplers:
+        # make new scope
+        def make_replacer(upsampler):
+            def forward(hidden_states, output_size=None):
+                return forward_upsample(upsampler, hidden_states, output_size)
+
+            return forward
+
+        upsampler.forward = make_replacer(upsampler)
+"""
+
+
+def replace_vae_attn_to_memory_efficient():
+    print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
+    flash_func = FlashAttentionFunction
+
+    def forward_flash_attn(self, hidden_states):
+        print("forward_flash_attn")
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_proj, key_proj, value_proj = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
+        )
+
+        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(out)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
+
+
 # endregion
 
 
 # region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正
@@ -2142,6 +2241,7 @@ def main(args):
     # xformers、Hypernetwork対応
     if not args.diffusers_xformers:
         replace_unet_modules(unet, not args.xformers, args.xformers)
+        replace_vae_modules(vae, not args.xformers, args.xformers)
 
     # tokenizerを読み込む
     print("loading tokenizer")
@@ -3175,8 +3275,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--vae_slices",
         type=int,
         default=None,
-        help=
-        "number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨"
+        help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨",
     )
     parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
    parser.add_argument(
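Note on the commented-out upsampler replacement above: its bfloat16 path swaps torch.nn.functional.interpolate for two repeat_interleave calls. For 2x nearest-neighbor upsampling the two are exactly equivalent, which a standalone check (arbitrary toy sizes) confirms:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)  # (batch, channels, height, width)
a = F.interpolate(x, scale_factor=2.0, mode="nearest")           # stock path
b = x.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2)  # bfloat16 path
print(torch.equal(a, b))  # True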
diff --git a/library/train_util.py b/library/train_util.py
index 4349345a..41afc13b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1765,6 +1765,7 @@ class FlashAttentionFunction(torch.autograd.function.Function):
 
 
 def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
+    # unet is not used currently, but it is here for future use
     if mem_eff_attn:
         replace_unet_cross_attn_to_memory_efficient()
     elif xformers:
@@ -1772,7 +1773,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
 
 
 def replace_unet_cross_attn_to_memory_efficient():
-    print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
+    print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
     flash_func = FlashAttentionFunction
 
     def forward_flash_attn(self, x, context=None, mask=None):
@@ -1812,7 +1813,7 @@ def replace_unet_cross_attn_to_memory_efficient():
 
 
 def replace_unet_cross_attn_to_xformers():
-    print("Replace CrossAttention.forward to use xformers")
+    print("CrossAttention.forward has been replaced to enable xformers.")
     try:
         import xformers.ops
     except ImportError:
@@ -1854,6 +1855,60 @@ def replace_unet_cross_attn_to_xformers():
     diffusers.models.attention.CrossAttention.forward = forward_xformers
 
 
+"""
+def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
+    # vae is not used currently, but it is here for future use
+    if mem_eff_attn:
+        replace_vae_attn_to_memory_efficient()
+    elif xformers:
+        # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
+        print("Use Diffusers xformers for VAE")
+        vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
+        vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
+
+
+def replace_vae_attn_to_memory_efficient():
+    print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
+    flash_func = FlashAttentionFunction
+
+    def forward_flash_attn(self, hidden_states):
+        print("forward_flash_attn")
+        q_bucket_size = 512
+        k_bucket_size = 1024
+
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_proj, key_proj, value_proj = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
+        )
+
+        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(out)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
+"""
+
+
 # endregion
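Note: the forward_flash_attn replacements above split and re-merge attention heads with einops.rearrange around the FlashAttentionFunction call. A standalone shape check of those two rearranges, with arbitrary example sizes:

import torch
from einops import rearrange

b, n, h, d = 2, 4096, 8, 64   # batch, tokens (height * width), heads, head dim
x = torch.randn(b, n, h * d)  # shape of query_proj / key_proj / value_proj

x = rearrange(x, "b n (h d) -> b h n d", h=h)  # split heads before attention
assert x.shape == (b, h, n, d)

x = rearrange(x, "b h n d -> b n (h d)")       # merge heads back afterwards
assert x.shape == (b, n, h * d)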