From 4e25c8f78e873224e753af6d0509bc9281acff65 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 11 Jun 2023 16:57:17 +0900 Subject: [PATCH] fix to work with Diffusers 0.17.0 --- gen_img_diffusers.py | 399 ++++++-------------------------- library/lpw_stable_diffusion.py | 4 +- library/model_util.py | 24 +- library/train_util.py | 2 +- 4 files changed, 86 insertions(+), 343 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 28f7323a..34857af3 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -82,7 +82,6 @@ from diffusers import ( StableDiffusionPipeline, ) from einops import rearrange -from torch import einsum from tqdm import tqdm from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig @@ -96,6 +95,7 @@ from networks.lora import LoRANetwork import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel +from library.original_unet import FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -136,341 +136,36 @@ USE_CUTOUTS = False 高速化のためのモジュール入れ替え """ -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE -# constants +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") -EPSILON = 1e-6 + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = q.shape[-1] ** -0.5 - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) - dp = einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = einsum("... i j, ... i d -> ... 
j d", ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None + unet.set_use_memory_efficient_attention(True, False) # TODO common train_util.py -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers(unet) - - -def replace_unet_cross_attn_to_memory_efficient(): - print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - # diffusers.models.attention.CrossAttention.forward = forward_flash_attn - # TODO U-Net側に移す - from library.original_unet import CrossAttention - CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(unet:UNet2DConditionModel): - print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention_xformers(True) - - # def forward_xformers(self, x, context=None, mask=None): - # h = self.heads - # q_in = self.to_q(x) - - # context = default(context, x) - # context = context.to(x.dtype) - - # if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - # context_k, context_v = self.hypernetwork.forward(x, context) - # context_k = context_k.to(x.dtype) - # context_v = context_v.to(x.dtype) - # else: - # context_k = context - # context_v = context - - # k_in = self.to_k(context_k) - # v_in = self.to_v(context_v) - - # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - # del q_in, k_in, v_in - - # q = q.contiguous() - # k = k.contiguous() - # v = v.contiguous() - # out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - # out = rearrange(out, "b n h d -> b n (h d)", h=h) - - # # diffusers 0.7.0~ - # out = self.to_out[0](out) - # out = self.to_out[1](out) - # return out - - # diffusers.models.attention.CrossAttention.forward = forward_xformers - - def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): if mem_eff_attn: replace_vae_attn_to_memory_efficient() elif xformers: # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - print("Use Diffusers xformers for VAE") - vae.set_use_memory_efficient_attention_xformers(True) - - """ - # VAEがbfloat16でメモリ消費が大きい問題を解決する - upsamplers = [] - for block in vae.decoder.up_blocks: - if block.upsamplers is not None: - 
upsamplers.extend(block.upsamplers) - - def forward_upsample(_self, hidden_states, output_size=None): - assert hidden_states.shape[1] == _self.channels - if _self.use_conv_transpose: - return _self.conv(hidden_states) - - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - assert output_size is None - # repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する - hidden_states = hidden_states.repeat_interleave(2, dim=-1) - hidden_states = hidden_states.repeat_interleave(2, dim=-2) - else: - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest") - - if _self.use_conv: - if _self.name == "conv": - hidden_states = _self.conv(hidden_states) - else: - hidden_states = _self.Conv2d_0(hidden_states) - return hidden_states - - # replace upsamplers - for upsampler in upsamplers: - # make new scope - def make_replacer(upsampler): - def forward(hidden_states, output_size=None): - return forward_upsample(upsampler, hidden_states, output_size) - - return forward - - upsampler.forward = make_replacer(upsampler) -""" - + replace_vae_attn_to_xformers() def replace_vae_attn_to_memory_efficient(): - print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction + print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func =FlashAttentionFunction - def forward_flash_attn(self, hidden_states): - print("forward_flash_attn") + def forward_flash_attn(self, hidden_states, **kwargs): q_bucket_size = 512 k_bucket_size = 1024 @@ -483,12 +178,12 @@ def replace_vae_attn_to_memory_efficient(): hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) ) out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) @@ -496,14 +191,62 @@ def replace_vae_attn_to_memory_efficient(): out = rearrange(out, "b h n d -> b n (h d)") # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + print("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = 
hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + diffusers.models.attention_processor.Attention.forward = forward_xformers # endregion diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 3e04b887..883707f7 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -464,7 +464,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: SchedulerMixin, - clip_skip: int, + # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, @@ -479,7 +479,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): feature_extractor=feature_extractor, requires_safety_checker=requires_safety_checker, ) - self.clip_skip = clip_skip + # self.clip_skip = clip_skip self.__init__additional__() # else: diff --git a/library/model_util.py b/library/model_util.py index 0fbc6590..ea1be513 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,7 +5,7 @@ import math import os import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline #, UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel @@ -127,17 +127,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = 
new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) @@ -192,8 +192,8 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] @@ -362,7 +362,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): # SDのv2では1*1のconv2dがlinearに変わっている # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 - if v2 and not config.get('use_linear_projection', False): + if v2 and not config.get("use_linear_projection", False): linear_transformer_to_conv(new_checkpoint) return new_checkpoint diff --git a/library/train_util.py b/library/train_util.py index b7cee937..8aa7f987 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3467,11 +3467,11 @@ def sample_images( unet=unet, tokenizer=tokenizer, scheduler=scheduler, - clip_skip=args.clip_skip, safety_checker=None, feature_extractor=None, requires_safety_checker=False, ) + pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する pipeline.to(device) save_dir = args.output_dir + "/sample"