Modify nn.MHA to attn with q/k/v

This commit is contained in:
Kohya S
2024-02-22 23:39:28 +09:00
parent 417f14d245
commit 3368fb1af7
7 changed files with 204 additions and 7 deletions

View File

@@ -20,7 +20,7 @@ import diffusers
import numpy as np
import torch
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory, get_preferred_device
init_ipex()
@@ -338,7 +338,7 @@ class PipelineLike:
self.clip_vision_model: CLIPVisionModelWithProjection = None
self.clip_vision_processor: CLIPImageProcessor = None
self.clip_vision_strength = 0.0
# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):

View File

@@ -5,6 +5,7 @@
import math
from types import SimpleNamespace
from typing import List, Optional
from einops import rearrange
import numpy as np
import torch
import torch.nn as nn
@@ -148,7 +149,7 @@ class EfficientNetEncoder(nn.Module):
The method to make it usable like VAE. It should be separated properly, but it is a temporary response.
"""
# latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
# x is -1 to 1, so we need to convert it to 0 to 1, and then preprocess it with EfficientNet's preprocessing.
x = (x + 1) / 2
x = EFFNET_PREPROCESS(x)
@@ -172,6 +173,7 @@ from torch.nn import Conv2d
from torch.nn import Linear
r"""
class Attention2D(nn.Module):
def __init__(self, c, nhead, dropout=0.0):
super().__init__()
@@ -185,6 +187,119 @@ class Attention2D(nn.Module):
x = self.attn(x, kv, kv, need_weights=False)[0]
x = x.permute(0, 2, 1).view(*orig_shape)
return x
"""
class Attention(nn.Module):
def __init__(self, c, nhead, dropout=0.0):
# dropout is for attn_output_weights, so we may not need it. however, if we use sdpa, we enable it.
# xformers and normal attn are not affected by dropout
super().__init__()
self.to_q = Linear(c, c, bias=True)
self.to_k = Linear(c, c, bias=True)
self.to_v = Linear(c, c, bias=True)
self.to_out = Linear(c, c, bias=True)
self.nhead = nhead
self.dropout = dropout
self.scale = (c // nhead) ** -0.5
# default is to use sdpa
self.use_memory_efficient_attention_xformers = False
self.use_sdpa = True
def set_use_xformers_or_sdpa(self, xformers, sdpa):
# print(f"Attention: set_use_xformers_or_sdpa: xformers={xformers}, sdpa={sdpa}")
self.use_memory_efficient_attention_xformers = xformers
self.use_sdpa = sdpa
def forward(self, q_in, k_in, v_in):
q_in = self.to_q(q_in)
k_in = self.to_k(k_in)
v_in = self.to_v(v_in)
if self.use_memory_efficient_attention_xformers:
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.nhead), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = self.forward_memory_efficient_xformers(q, k, v)
del q, k, v
out = rearrange(out, "b n h d -> b n (h d)", h=self.nhead)
elif self.use_sdpa:
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.nhead), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = self.forward_sdpa(q, k, v)
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=self.nhead)
else:
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=self.nhead), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = self._attention(q, k, v)
del q, k, v
out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)
return self.to_out(out)
def _attention(self, query, key, value):
# if self.upcast_attention:
# query = query.float()
# key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value)
return hidden_states
def forward_memory_efficient_xformers(self, q, k, v):
import xformers.ops
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
del q, k, v
return out
def forward_sdpa(self, q, k, v):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=False)
return out
class Attention2D(nn.Module):
r"""
to_q/k/v を個別に重みをもつように変更
modified to have separate weights for to_q/k/v
"""
def __init__(self, c, nhead, dropout=0.0):
super().__init__()
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
self.attn = Attention(c, nhead, dropout=dropout) # , bias=True, batch_first=True)
def forward(self, x, kv, self_attn=False):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
def set_use_xformers_or_sdpa(self, xformers, sdpa):
self.attn.set_use_xformers_or_sdpa(xformers, sdpa)
class LayerNorm2d(nn.LayerNorm):
@@ -262,6 +377,9 @@ class AttnBlock(nn.Module):
def set_gradient_checkpointing(self, value):
self.gradient_checkpointing = value
def set_use_xformers_or_sdpa(self, xformers, sdpa):
self.attention.set_use_xformers_or_sdpa(xformers, sdpa)
def forward_body(self, x, kv):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
@@ -657,6 +775,12 @@ class StageB(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def set_use_xformers_or_sdpa(self, xformers, sdpa):
for block in self.down_blocks + self.up_blocks:
for layer in block:
if hasattr(layer, "set_use_xformers_or_sdpa"):
layer.set_use_xformers_or_sdpa(xformers, sdpa)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
@@ -920,6 +1044,12 @@ class StageC(nn.Module):
if hasattr(layer, "set_gradient_checkpointing"):
layer.set_gradient_checkpointing(value)
def set_use_xformers_or_sdpa(self, xformers, sdpa):
for block in self.down_blocks + self.up_blocks:
for layer in block:
if hasattr(layer, "set_use_xformers_or_sdpa"):
layer.set_use_xformers_or_sdpa(xformers, sdpa)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2

View File

@@ -103,6 +103,9 @@ def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.
generator_c = sc.StageC()
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
stage_c_checkpoint = load_file(stage_c_checkpoint_path)
stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
logger.info(info)
@@ -115,6 +118,9 @@ def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.
generator_b = sc.StageB()
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
stage_b_checkpoint = load_file(stage_b_checkpoint_path)
stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
logger.info(info)
@@ -189,6 +195,55 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
return previewer
def convert_state_dict_mha_to_normal_attn(state_dict):
# convert nn.MultiheadAttention to to_q/k/v and to_out
print("convert_state_dict_mha_to_normal_attn")
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "in_proj_bias" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
elif "in_proj_weight" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
elif "out_proj.bias" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "to_out.bias")] = value
elif "out_proj.weight" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "to_out.weight")] = value
return state_dict
def convert_state_dict_normal_attn_to_mha(state_dict):
# convert to_q/k/v and to_out to nn.MultiheadAttention
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "to_q.bias" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
elif "to_q.weight" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
elif "to_out.bias" in key:
v = state_dict.pop(key)
state_dict[key.replace("to_out.bias", "out_proj.bias")] = v
elif "to_out.weight" in key:
v = state_dict.pop(key)
state_dict[key.replace("to_out.weight", "out_proj.weight")] = v
return state_dict
def get_sai_model_spec(args, lora=False):
timestamp = time.time()
@@ -230,6 +285,8 @@ def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadat
if save_dtype is not None:
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
save_file(state_dict, ckpt_file, metadata=sai_metadata)
# save text model

View File

@@ -29,7 +29,7 @@ def main(file):
for key, value in values:
value = value.to(torch.float32)
logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -40,9 +40,15 @@ def main(args):
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
generator_c.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
generator_b.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# CLIP encoders
tokenizer = sc_utils.load_tokenizer(args)
@@ -332,6 +338,8 @@ if __name__ == "__main__":
sc_utils.add_text_model_arguments(parser)
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--xformers", action="store_true")
parser.add_argument("--sdpa", action="store_true")
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
parser.add_argument(

View File

@@ -295,10 +295,9 @@ class NetworkTrainer:
# text_encoder is List[CLIPTextModel] or CLIPTextModel
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# # モデルに xformers とか memory efficient attention を組み込む
# モデルに xformers とか memory efficient attention を組み込む
# train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
# vae.set_use_memory_efficient_attention_xformers(args.xformers)
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))

View File

@@ -147,6 +147,9 @@ def train(args):
else:
previewer = None
# モデルに xformers とか memory efficient attention を組み込む
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
effnet.to(accelerator.device, dtype=effnet_dtype)