fix to work with Diffusers 0.17.0
@@ -82,7 +82,6 @@ from diffusers import (
    StableDiffusionPipeline,
)
from einops import rearrange
from torch import einsum
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig
@@ -96,6 +95,7 @@ from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel
from library.original_unet import FlashAttentionFunction

from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -136,341 +136,36 @@ USE_CUTOUTS = False
Module replacement for speed-up
"""

# CrossAttention that uses FlashAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE

# constants

EPSILON = 1e-6

def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
    if mem_eff_attn:
        print("Enable memory efficient attention for U-Net")
        # This is our own U-Net, not the Diffusers one, so the modules do not need to be replaced
        unet.set_use_memory_efficient_attention(False, True)
    elif xformers:
        print("Enable xformers for U-Net")
        try:
            import xformers.ops
        except ImportError:
            raise ImportError("No xformers / xformersがインストールされていないようです")

        unet.set_use_memory_efficient_attention(True, False)

# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135

class FlashAttentionFunction(torch.autograd.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
                dp = einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None

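For reference, the chunked forward above should reproduce plain softmax attention up to numerical error. A quick self-contained check, not part of the commit (shapes and bucket sizes are arbitrary):

import torch

# smoke test: compare FlashAttentionFunction against naive softmax attention
b, h, n, d = 2, 8, 256, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

out_flash = FlashAttentionFunction.apply(q, k, v, None, False, 64, 128)

scale = d**-0.5
attn = torch.softmax(torch.einsum("bhid,bhjd->bhij", q, k) * scale, dim=-1)
out_naive = torch.einsum("bhij,bhjd->bhid", attn, v)

assert torch.allclose(out_flash, out_naive, atol=1e-4)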
# TODO common train_util.py
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
    if mem_eff_attn:
        replace_unet_cross_attn_to_memory_efficient()
    elif xformers:
        replace_unet_cross_attn_to_xformers(unet)

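Either version of the dispatcher is driven the same way from the generation script. A minimal sketch of a call site (the model ID and flag values are illustrative, not from the commit; the new variant expects the U-Net from library.original_unet):

from diffusers import StableDiffusionPipeline

# hypothetical wiring; sd-scripts sets these flags from argparse
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
replace_unet_modules(pipe.unet, mem_eff_attn=True, xformers=False)  # FlashAttention path
# or: replace_unet_modules(pipe.unet, mem_eff_attn=False, xformers=True)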
def replace_unet_cross_attn_to_memory_efficient():
    print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork")
    flash_func = FlashAttentionFunction

    def forward_flash_attn(self, x, context=None, mask=None):
        q_bucket_size = 512
        k_bucket_size = 1024

        h = self.heads
        q = self.to_q(x)

        context = context if context is not None else x
        context = context.to(x.dtype)

        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
            context_k, context_v = self.hypernetwork.forward(x, context)
            context_k = context_k.to(x.dtype)
            context_v = context_v.to(x.dtype)
        else:
            context_k = context
            context_v = context

        k = self.to_k(context_k)
        v = self.to_v(context_v)
        del context, x

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, "b h n d -> b n (h d)")

        # diffusers 0.7.0~
        out = self.to_out[0](out)
        out = self.to_out[1](out)
        return out

    # diffusers.models.attention.CrossAttention.forward = forward_flash_attn
    # TODO move this to the U-Net side
    from library.original_unet import CrossAttention

    CrossAttention.forward = forward_flash_attn

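The hypernetwork branch above only assumes an object whose forward(x, context) returns replacement key/value contexts. A hypothetical stub showing that interface (not NAI's actual implementation):

import torch

class DummyHypernetwork(torch.nn.Module):
    # hypothetical stand-in: returns the key/value contexts unchanged
    def forward(self, x, context):
        context_k = context  # a real hypernetwork would transform these
        context_v = context
        return context_k, context_v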
def replace_unet_cross_attn_to_xformers(unet: UNet2DConditionModel):
    print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
    try:
        import xformers.ops
    except ImportError:
        raise ImportError("No xformers / xformersがインストールされていないようです")

    unet.set_use_memory_efficient_attention_xformers(True)

    # def forward_xformers(self, x, context=None, mask=None):
    #     h = self.heads
    #     q_in = self.to_q(x)

    #     context = default(context, x)
    #     context = context.to(x.dtype)

    #     if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
    #         context_k, context_v = self.hypernetwork.forward(x, context)
    #         context_k = context_k.to(x.dtype)
    #         context_v = context_v.to(x.dtype)
    #     else:
    #         context_k = context
    #         context_v = context

    #     k_in = self.to_k(context_k)
    #     v_in = self.to_v(context_v)

    #     q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
    #     del q_in, k_in, v_in

    #     q = q.contiguous()
    #     k = k.contiguous()
    #     v = v.contiguous()
    #     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best kernel automatically

    #     out = rearrange(out, "b n h d -> b n (h d)", h=h)

    #     # diffusers 0.7.0~
    #     out = self.to_out[0](out)
    #     out = self.to_out[1](out)
    #     return out

    # diffusers.models.attention.CrossAttention.forward = forward_xformers

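Note the layout difference: xformers.ops.memory_efficient_attention expects tensors shaped (batch, seq_len, heads, head_dim), hence the "b n (h d) -> b n h d" rearrange in the commented-out path, while FlashAttentionFunction above takes (batch, heads, seq_len, head_dim). A standalone sketch of the xformers layout (arbitrary shapes; assumes a CUDA device with xformers installed):

import torch
import xformers.ops
from einops import rearrange

b, n, h, d = 2, 77, 8, 40
q, k, v = (torch.randn(b, n, h * d, device="cuda", dtype=torch.float16) for _ in range(3))

# xformers wants (batch, seq_len, heads, head_dim), contiguous
q, k, v = (rearrange(t, "b n (h d) -> b n h d", h=h).contiguous() for t in (q, k, v))
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = rearrange(out, "b n h d -> b n (h d)")  # back to (batch, seq_len, inner_dim)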
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
    if mem_eff_attn:
        replace_vae_attn_to_memory_efficient()
    elif xformers:
        # For now, use Diffusers' xformers; only the mid block has attention
        print("Use Diffusers xformers for VAE")
        vae.set_use_memory_efficient_attention_xformers(True)

        """
        # Works around the VAE's high memory consumption in bfloat16
        upsamplers = []
        for block in vae.decoder.up_blocks:
            if block.upsamplers is not None:
                upsamplers.extend(block.upsamplers)

        def forward_upsample(_self, hidden_states, output_size=None):
            assert hidden_states.shape[1] == _self.channels
            if _self.use_conv_transpose:
                return _self.conv(hidden_states)

            dtype = hidden_states.dtype
            if dtype == torch.bfloat16:
                assert output_size is None
                # repeat_interleave is very slow, but it is called rarely, so this is acceptable
                hidden_states = hidden_states.repeat_interleave(2, dim=-1)
                hidden_states = hidden_states.repeat_interleave(2, dim=-2)
            else:
                if hidden_states.shape[0] >= 64:
                    hidden_states = hidden_states.contiguous()

                # if `output_size` is passed we force the interpolation output
                # size and do not make use of `scale_factor=2`
                if output_size is None:
                    hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
                else:
                    hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest")

            if _self.use_conv:
                if _self.name == "conv":
                    hidden_states = _self.conv(hidden_states)
                else:
                    hidden_states = _self.Conv2d_0(hidden_states)
            return hidden_states

        # replace upsamplers
        for upsampler in upsamplers:
            # make new scope
            def make_replacer(upsampler):
                def forward(hidden_states, output_size=None):
                    return forward_upsample(upsampler, hidden_states, output_size)

                return forward

            upsampler.forward = make_replacer(upsampler)
        """

        replace_vae_attn_to_xformers()

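The commented-out workaround relies on the fact that two repeat_interleave calls reproduce nearest-neighbor x2 upsampling exactly, which sidesteps interpolate for bfloat16 inputs. A quick check (not from the commit):

import torch

# nearest-neighbor x2 upsampling via repeat_interleave equals interpolate(mode="nearest")
x = torch.randn(1, 4, 8, 8)
a = x.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2)
b = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
assert torch.equal(a, b)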
def replace_vae_attn_to_memory_efficient():
    print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
    flash_func = FlashAttentionFunction
    print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
    flash_func = FlashAttentionFunction

    def forward_flash_attn(self, hidden_states):
        print("forward_flash_attn")
    def forward_flash_attn(self, hidden_states, **kwargs):
        q_bucket_size = 512
        k_bucket_size = 1024

@@ -483,12 +178,12 @@ def replace_vae_attn_to_memory_efficient():
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)
        query_proj = self.to_q(hidden_states)
        key_proj = self.to_k(hidden_states)
        value_proj = self.to_v(hidden_states)

        query_proj, key_proj, value_proj = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj)
        )

        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
@@ -496,14 +191,62 @@ def replace_vae_attn_to_memory_efficient():
        out = rearrange(out, "b h n d -> b n (h d)")

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)
        # linear proj
        hidden_states = self.to_out[0](out)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
    diffusers.models.attention_processor.Attention.forward = forward_flash_attn

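The duplicated projection lines above reflect the Diffusers 0.17 rename of AttentionBlock (query/key/value/proj_attn) to attention_processor.Attention (to_q/to_k/to_v/to_out). A hypothetical helper for code that must run on both sides of the rename (not part of the commit):

def get_qkv_projections(attn):
    # hypothetical compatibility shim across the Diffusers 0.17 rename
    if hasattr(attn, "to_q"):  # diffusers >= 0.17: attention_processor.Attention
        return attn.to_q, attn.to_k, attn.to_v
    return attn.query, attn.key, attn.value  # older AttentionBlock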
def replace_vae_attn_to_xformers():
    print("VAE: Attention.forward has been replaced to xformers")
    import xformers.ops

    def forward_xformers(self, hidden_states, **kwargs):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.to_q(hidden_states)
        key_proj = self.to_k(hidden_states)
        value_proj = self.to_v(hidden_states)

        query_proj, key_proj, value_proj = map(
            lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj)
        )

        query_proj = query_proj.contiguous()
        key_proj = key_proj.contiguous()
        value_proj = value_proj.contiguous()
        out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)

        out = rearrange(out, "b n h d -> b n (h d)")

        # compute next hidden_states
        # linear proj
        hidden_states = self.to_out[0](out)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    diffusers.models.attention_processor.Attention.forward = forward_xformers

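As with the U-Net, the VAE replacement is driven from the top level. A minimal sketch of a call site, assuming a loaded AutoencoderKL (model ID illustrative):

from diffusers import AutoencoderKL

# hypothetical wiring, mirroring the U-Net setup above
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
replace_vae_modules(vae, mem_eff_attn=False, xformers=True)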
# endregion