fix to work with Diffusers 0.17.0

Author: ykume
Date: 2023-06-11 16:57:17 +09:00
Parent: 7f6b581ef8
Commit: 4e25c8f78e
4 changed files with 86 additions and 343 deletions

View File

@@ -82,7 +82,6 @@ from diffusers import (
     StableDiffusionPipeline,
 )
 from einops import rearrange
-from torch import einsum
 from tqdm import tqdm
 from torchvision import transforms
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig
@@ -96,6 +95,7 @@ from networks.lora import LoRANetwork
 import tools.original_control_net as original_control_net
 from tools.original_control_net import ControlNetInfo
 from library.original_unet import UNet2DConditionModel
+from library.original_unet import FlashAttentionFunction
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -136,341 +136,36 @@ USE_CUTOUTS = False
 Module replacement for speedup
 """
-# CrossAttention that uses FlashAttention
-# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
-# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
-
-# constants
-EPSILON = 1e-6
-
-# helper functions
-def exists(val):
-    return val is not None
-
-def default(val, d):
-    return val if exists(val) else d
-
-# flash attention forwards and backwards
-# https://arxiv.org/abs/2205.14135
-class FlashAttentionFunction(torch.autograd.Function):
-    @staticmethod
-    @torch.no_grad()
-    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
-        """Algorithm 2 in the paper"""
-        device = q.device
-        dtype = q.dtype
-        max_neg_value = -torch.finfo(q.dtype).max
-        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
-        o = torch.zeros_like(q)
-        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
-        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
-
-        scale = q.shape[-1] ** -0.5
-
-        if not exists(mask):
-            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
-        else:
-            mask = rearrange(mask, "b n -> b 1 1 n")
-            mask = mask.split(q_bucket_size, dim=-1)
-
-        row_splits = zip(
-            q.split(q_bucket_size, dim=-2),
-            o.split(q_bucket_size, dim=-2),
-            mask,
-            all_row_sums.split(q_bucket_size, dim=-2),
-            all_row_maxes.split(q_bucket_size, dim=-2),
-        )
-
-        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
-            q_start_index = ind * q_bucket_size - qk_len_diff
-
-            col_splits = zip(
-                k.split(k_bucket_size, dim=-2),
-                v.split(k_bucket_size, dim=-2),
-            )
-
-            for k_ind, (kc, vc) in enumerate(col_splits):
-                k_start_index = k_ind * k_bucket_size
-
-                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
-
-                if exists(row_mask):
-                    attn_weights.masked_fill_(~row_mask, max_neg_value)
-
-                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
-                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
-                        q_start_index - k_start_index + 1
-                    )
-                    attn_weights.masked_fill_(causal_mask, max_neg_value)
-
-                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
-                attn_weights -= block_row_maxes
-                exp_weights = torch.exp(attn_weights)
-
-                if exists(row_mask):
-                    exp_weights.masked_fill_(~row_mask, 0.0)
-
-                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
-
-                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
-
-                exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)
-
-                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
-                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
-
-                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
-
-                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
-
-                row_maxes.copy_(new_row_maxes)
-                row_sums.copy_(new_row_sums)
-
-        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
-        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
-
-        return o
-
-    @staticmethod
-    @torch.no_grad()
-    def backward(ctx, do):
-        """Algorithm 4 in the paper"""
-        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
-        q, k, v, o, l, m = ctx.saved_tensors
-
-        device = q.device
-
-        max_neg_value = -torch.finfo(q.dtype).max
-        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
-        dq = torch.zeros_like(q)
-        dk = torch.zeros_like(k)
-        dv = torch.zeros_like(v)
-
-        row_splits = zip(
-            q.split(q_bucket_size, dim=-2),
-            o.split(q_bucket_size, dim=-2),
-            do.split(q_bucket_size, dim=-2),
-            mask,
-            l.split(q_bucket_size, dim=-2),
-            m.split(q_bucket_size, dim=-2),
-            dq.split(q_bucket_size, dim=-2),
-        )
-
-        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
-            q_start_index = ind * q_bucket_size - qk_len_diff
-
-            col_splits = zip(
-                k.split(k_bucket_size, dim=-2),
-                v.split(k_bucket_size, dim=-2),
-                dk.split(k_bucket_size, dim=-2),
-                dv.split(k_bucket_size, dim=-2),
-            )
-
-            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
-                k_start_index = k_ind * k_bucket_size
-
-                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
-
-                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
-                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
-                        q_start_index - k_start_index + 1
-                    )
-                    attn_weights.masked_fill_(causal_mask, max_neg_value)
-
-                exp_attn_weights = torch.exp(attn_weights - mc)
-
-                if exists(row_mask):
-                    exp_attn_weights.masked_fill_(~row_mask, 0.0)
-
-                p = exp_attn_weights / lc
-
-                dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
-                dp = einsum("... i d, ... j d -> ... i j", doc, vc)
-
-                D = (doc * oc).sum(dim=-1, keepdims=True)
-                ds = p * scale * (dp - D)
-
-                dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
-                dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
-
-                dqc.add_(dq_chunk)
-                dkc.add_(dk_chunk)
-                dvc.add_(dv_chunk)
-
-        return dq, dk, dv, None, None, None, None
-# TODO common train_util.py
 def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
     if mem_eff_attn:
-        replace_unet_cross_attn_to_memory_efficient()
+        print("Enable memory efficient attention for U-Net")
+        # this is our own U-Net, not the Diffusers one, so its modules do not need to be replaced
+        unet.set_use_memory_efficient_attention(False, True)
     elif xformers:
-        replace_unet_cross_attn_to_xformers(unet)
+        print("Enable xformers for U-Net")
-def replace_unet_cross_attn_to_memory_efficient():
-    print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork")
-    flash_func = FlashAttentionFunction
-
-    def forward_flash_attn(self, x, context=None, mask=None):
-        q_bucket_size = 512
-        k_bucket_size = 1024
-
-        h = self.heads
-        q = self.to_q(x)
-
-        context = context if context is not None else x
-        context = context.to(x.dtype)
-
-        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-            context_k, context_v = self.hypernetwork.forward(x, context)
-            context_k = context_k.to(x.dtype)
-            context_v = context_v.to(x.dtype)
-        else:
-            context_k = context
-            context_v = context
-
-        k = self.to_k(context_k)
-        v = self.to_v(context_v)
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
-
-        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
-        out = rearrange(out, "b h n d -> b n (h d)")
-
-        # diffusers 0.7.0~
-        out = self.to_out[0](out)
-        out = self.to_out[1](out)
-        return out
-
-    # diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-    # TODO move this to the U-Net side
-    from library.original_unet import CrossAttention
-
-    CrossAttention.forward = forward_flash_attn
-def replace_unet_cross_attn_to_xformers(unet: UNet2DConditionModel):
-    print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
         try:
             import xformers.ops
         except ImportError:
             raise ImportError("No xformers / xformersがインストールされていないようです")
-    unet.set_use_memory_efficient_attention_xformers(True)
+        unet.set_use_memory_efficient_attention(True, False)
-    # def forward_xformers(self, x, context=None, mask=None):
-    #     h = self.heads
-    #     q_in = self.to_q(x)
-
-    #     context = default(context, x)
-    #     context = context.to(x.dtype)
-
-    #     if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-    #         context_k, context_v = self.hypernetwork.forward(x, context)
-    #         context_k = context_k.to(x.dtype)
-    #         context_v = context_v.to(x.dtype)
-    #     else:
-    #         context_k = context
-    #         context_v = context
-
-    #     k_in = self.to_k(context_k)
-    #     v_in = self.to_v(context_v)
-
-    #     q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
-    #     del q_in, k_in, v_in
-    #     q = q.contiguous()
-    #     k = k.contiguous()
-    #     v = v.contiguous()
-    #     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best kernel automatically
-
-    #     out = rearrange(out, "b n h d -> b n (h d)", h=h)
-
-    #     # diffusers 0.7.0~
-    #     out = self.to_out[0](out)
-    #     out = self.to_out[1](out)
-    #     return out
-
-    # diffusers.models.attention.CrossAttention.forward = forward_xformers
-# TODO common train_util.py
 def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
     if mem_eff_attn:
         replace_vae_attn_to_memory_efficient()
     elif xformers:
         # for now use Diffusers' xformers; only the MidBlock has Attention
-        print("Use Diffusers xformers for VAE")
-        vae.set_use_memory_efficient_attention_xformers(True)
+        replace_vae_attn_to_xformers()
"""
# VAEがbfloat16でメモリ消費が大きい問題を解決する
upsamplers = []
for block in vae.decoder.up_blocks:
if block.upsamplers is not None:
upsamplers.extend(block.upsamplers)
def forward_upsample(_self, hidden_states, output_size=None):
assert hidden_states.shape[1] == _self.channels
if _self.use_conv_transpose:
return _self.conv(hidden_states)
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
assert output_size is None
# repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する
hidden_states = hidden_states.repeat_interleave(2, dim=-1)
hidden_states = hidden_states.repeat_interleave(2, dim=-2)
else:
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
else:
hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest")
if _self.use_conv:
if _self.name == "conv":
hidden_states = _self.conv(hidden_states)
else:
hidden_states = _self.Conv2d_0(hidden_states)
return hidden_states
# replace upsamplers
for upsampler in upsamplers:
# make new scope
def make_replacer(upsampler):
def forward(hidden_states, output_size=None):
return forward_upsample(upsampler, hidden_states, output_size)
return forward
upsampler.forward = make_replacer(upsampler)
"""
 def replace_vae_attn_to_memory_efficient():
-    print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
+    print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
     flash_func = FlashAttentionFunction

-    def forward_flash_attn(self, hidden_states):
+    def forward_flash_attn(self, hidden_states, **kwargs):
+        print("forward_flash_attn")
         q_bucket_size = 512
         k_bucket_size = 1024
@@ -483,12 +178,12 @@ def replace_vae_attn_to_memory_efficient():
         hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

         # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
+        query_proj = self.to_q(hidden_states)
+        key_proj = self.to_k(hidden_states)
+        value_proj = self.to_v(hidden_states)

         query_proj, key_proj, value_proj = map(
-            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj)
         )

         out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
@@ -496,14 +191,62 @@ def replace_vae_attn_to_memory_efficient():
         out = rearrange(out, "b h n d -> b n (h d)")

         # compute next hidden_states
-        hidden_states = self.proj_attn(hidden_states)
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
         hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

         # res connect and rescale
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states

-    diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
+    diffusers.models.attention_processor.Attention.forward = forward_flash_attn
+def replace_vae_attn_to_xformers():
+    print("VAE: Attention.forward has been replaced to xformers")
+    import xformers.ops
+
+    def forward_xformers(self, hidden_states, **kwargs):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.to_q(hidden_states)
+        key_proj = self.to_k(hidden_states)
+        value_proj = self.to_v(hidden_states)
+
+        query_proj, key_proj, value_proj = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj)
+        )
+
+        query_proj = query_proj.contiguous()
+        key_proj = key_proj.contiguous()
+        value_proj = value_proj.contiguous()
+        out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+
+        # compute next hidden_states
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    diffusers.models.attention_processor.Attention.forward = forward_xformers
 # endregion
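Both attention patches above track the same underlying Diffusers 0.17 change: the VAE's AttentionBlock (query/key/value, proj_attn, num_heads) became attention_processor.Attention (to_q/to_k/to_v, to_out[0]/to_out[1], heads). A minimal probe sketch, not from this commit, for checking which attribute set a loaded VAE actually exposes before applying a monkey patch like the ones above; the vae.encoder.mid_block.attentions[0] path is an assumption about the standard AutoencoderKL layout:

    from diffusers import AutoencoderKL

    def vae_attention_style(vae: AutoencoderKL) -> str:
        # the VAE only has attention in its mid blocks; probe the encoder's one
        attn = vae.encoder.mid_block.attentions[0]
        if hasattr(attn, "to_q"):
            return "Attention (Diffusers >= 0.17): to_q/to_k/to_v, to_out[0]"
        return "AttentionBlock (older Diffusers): query/key/value, proj_attn"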

View File

@@ -464,7 +464,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
-        clip_skip: int,
+        # clip_skip: int,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
         requires_safety_checker: bool = True,
@@ -479,7 +479,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             feature_extractor=feature_extractor,
             requires_safety_checker=requires_safety_checker,
         )
-        self.clip_skip = clip_skip
+        # self.clip_skip = clip_skip
         self.__init__additional__()

     # else:

View File

@@ -5,7 +5,7 @@ import math
 import os
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
-from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline #, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
 from safetensors.torch import load_file, save_file
 from library.original_unet import UNet2DConditionModel
@@ -127,17 +127,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias") new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight") new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "query.bias") new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "key.weight") new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "key.bias") new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "value.weight") new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "value.bias") new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight") new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias") new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
@@ -192,8 +192,8 @@ def assign_to_checkpoint(
             new_path = new_path.replace(replacement["old"], replacement["new"])

         # proj_attn.weight has to be converted from conv 1D to linear
-        if "proj_attn.weight" in new_path:
-            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
         else:
             checkpoint[new_path] = old_checkpoint[path["old"]]
@@ -362,7 +362,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
     # in SD v2, the 1x1 conv2d was changed to a linear layer
     # the Diffusers side was mistakenly left as conv2d, so conversion is needed
-    if v2 and not config.get('use_linear_projection', False):
+    if v2 and not config.get("use_linear_projection", False):
         linear_transformer_to_conv(new_checkpoint)

     return new_checkpoint

View File

@@ -3467,11 +3467,11 @@ def sample_images(
         unet=unet,
         tokenizer=tokenizer,
         scheduler=scheduler,
-        clip_skip=args.clip_skip,
         safety_checker=None,
         feature_extractor=None,
         requires_safety_checker=False,
     )
+    pipeline.clip_skip = args.clip_skip  # clip_skip cannot be passed to the Pipeline constructor, so set it afterwards
     pipeline.to(device)

     save_dir = args.output_dir + "/sample"
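Caller-side sketch of the clip_skip pattern shown above: the long prompt weighting pipeline is built without clip_skip and the value is attached afterwards, as sample_images() now does. Assumptions: StableDiffusionLongPromptWeightingPipeline is imported from the repository's pipeline module, the model id is only an example, and the Diffusers U-Net shown here stands in for the library's own U-Net that sample_images() actually passes.

    from diffusers import StableDiffusionPipeline

    base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example checkpoint

    pipeline = StableDiffusionLongPromptWeightingPipeline(
        text_encoder=base.text_encoder,
        vae=base.vae,
        unet=base.unet,
        tokenizer=base.tokenizer,
        scheduler=base.scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )
    pipeline.clip_skip = 2  # set after construction; the constructor argument was removed above
    pipeline.to("cuda")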