Modify nn.MHA to attn with q/k/v

Kohya S
2024-02-22 23:39:28 +09:00
parent 417f14d245
commit 3368fb1af7
7 changed files with 204 additions and 7 deletions

View File

@@ -20,7 +20,7 @@ import diffusers
import numpy as np
import torch
-from library.ipex_interop import init_ipex
+from library.device_utils import init_ipex, clean_memory, get_preferred_device

init_ipex()
@@ -338,7 +338,7 @@ class PipelineLike:
        self.clip_vision_model: CLIPVisionModelWithProjection = None
        self.clip_vision_processor: CLIPImageProcessor = None
        self.clip_vision_strength = 0.0

        # Textual Inversion
        self.token_replacements_list = []
        for _ in range(len(self.text_encoders)):

View File

@@ -5,6 +5,7 @@
import math
from types import SimpleNamespace
from typing import List, Optional
+from einops import rearrange
import numpy as np
import torch
import torch.nn as nn
@@ -148,7 +149,7 @@ class EfficientNetEncoder(nn.Module):
        The method to make it usable like VAE. It should be separated properly, but this is a temporary workaround.
        """
        # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
        # x is -1 to 1, so we need to convert it to 0 to 1, and then preprocess it with EfficientNet's preprocessing.
        x = (x + 1) / 2
        x = EFFNET_PREPROCESS(x)
@@ -172,6 +173,7 @@ from torch.nn import Conv2d
from torch.nn import Linear

+r"""
class Attention2D(nn.Module):
    def __init__(self, c, nhead, dropout=0.0):
        super().__init__()
@@ -185,6 +187,119 @@ class Attention2D(nn.Module):
        x = self.attn(x, kv, kv, need_weights=False)[0]
        x = x.permute(0, 2, 1).view(*orig_shape)
        return x
+"""
+
+
+class Attention(nn.Module):
+    def __init__(self, c, nhead, dropout=0.0):
+        # dropout is for attn_output_weights, so we may not need it. however, if we use sdpa, we enable it.
+        # xformers and normal attn are not affected by dropout
+        super().__init__()
+
+        self.to_q = Linear(c, c, bias=True)
+        self.to_k = Linear(c, c, bias=True)
+        self.to_v = Linear(c, c, bias=True)
+        self.to_out = Linear(c, c, bias=True)
+        self.nhead = nhead
+        self.dropout = dropout
+        self.scale = (c // nhead) ** -0.5
+
+        # default is to use sdpa
+        self.use_memory_efficient_attention_xformers = False
+        self.use_sdpa = True
+
+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        # print(f"Attention: set_use_xformers_or_sdpa: xformers={xformers}, sdpa={sdpa}")
+        self.use_memory_efficient_attention_xformers = xformers
+        self.use_sdpa = sdpa
+
+    def forward(self, q_in, k_in, v_in):
+        q_in = self.to_q(q_in)
+        k_in = self.to_k(k_in)
+        v_in = self.to_v(v_in)
+
+        if self.use_memory_efficient_attention_xformers:
+            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.nhead), (q_in, k_in, v_in))
+            del q_in, k_in, v_in
+            out = self.forward_memory_efficient_xformers(q, k, v)
+            del q, k, v
+            out = rearrange(out, "b n h d -> b n (h d)", h=self.nhead)
+        elif self.use_sdpa:
+            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.nhead), (q_in, k_in, v_in))
+            del q_in, k_in, v_in
+            out = self.forward_sdpa(q, k, v)
+            del q, k, v
+            out = rearrange(out, "b h n d -> b n (h d)", h=self.nhead)
+        else:
+            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=self.nhead), (q_in, k_in, v_in))
+            del q_in, k_in, v_in
+            out = self._attention(q, k, v)
+            del q, k, v
+            out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)
+
+        return self.to_out(out)
+
+    def _attention(self, query, key, value):
+        # if self.upcast_attention:
+        #     query = query.float()
+        #     key = key.float()
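+        # baddbmm computes beta * input + alpha * (query @ key^T); with beta=0 the
+        # uninitialized empty() buffer is never read, so this is exactly scale * QK^T.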
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
+        attention_probs = attention_scores.softmax(dim=-1)
+
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
+        hidden_states = torch.bmm(attention_probs, value)
+        return hidden_states
+
+    def forward_memory_efficient_xformers(self, q, k, v):
+        import xformers.ops
+
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best available kernel / 最適なのを選んでくれる
+        del q, k, v
+
+        return out
+
+    def forward_sdpa(self, q, k, v):
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=False)
+        return out
+
+
+class Attention2D(nn.Module):
+    r"""
+    to_q/k/v を個別に重みをもつように変更
+    modified to have separate weights for to_q/k/v
+    """
+
+    def __init__(self, c, nhead, dropout=0.0):
+        super().__init__()
+        # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
+        self.attn = Attention(c, nhead, dropout=dropout)  # , bias=True, batch_first=True)
+
+    def forward(self, x, kv, self_attn=False):
+        orig_shape = x.shape
+        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # Bx4xHxW -> Bx(HxW)x4
+        if self_attn:
+            kv = torch.cat([x, kv], dim=1)
+        # x = self.attn(x, kv, kv, need_weights=False)[0]
+        x = self.attn(x, kv, kv)
+        x = x.permute(0, 2, 1).view(*orig_shape)
+        return x
+
+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        self.attn.set_use_xformers_or_sdpa(xformers, sdpa)


class LayerNorm2d(nn.LayerNorm):
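nn.MultiheadAttention keeps the three input projections stacked row-wise in a single in_proj_weight, which is why the checkpoint converter later in this commit can recover to_q/to_k/to_v with torch.chunk(..., 3, dim=0). A minimal standalone sketch of that equivalence, assuming the Attention class above is in scope (sizes below are illustrative):

    import torch
    import torch.nn as nn

    c, nhead = 64, 4
    mha = nn.MultiheadAttention(c, nhead, bias=True, batch_first=True).eval()
    attn = Attention(c, nhead)  # the class added above; SDPA path by default

    with torch.no_grad():
        # in_proj_weight/bias stack W_q, W_k, W_v along dim 0, so chunking recovers them
        qw, kw, vw = torch.chunk(mha.in_proj_weight, 3, dim=0)
        qb, kb, vb = torch.chunk(mha.in_proj_bias, 3, dim=0)
        for lin, w, b in [(attn.to_q, qw, qb), (attn.to_k, kw, kb), (attn.to_v, vw, vb), (attn.to_out, mha.out_proj.weight, mha.out_proj.bias)]:
            lin.weight.copy_(w)
            lin.bias.copy_(b)

    x = torch.randn(2, 16, c)
    assert torch.allclose(mha(x, x, x, need_weights=False)[0], attn(x, x, x), atol=1e-5)

With dropout at 0.0, the SDPA path matches nn.MultiheadAttention to within floating-point tolerance, so the swap does not change model outputs.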
@@ -262,6 +377,9 @@ class AttnBlock(nn.Module):
    def set_gradient_checkpointing(self, value):
        self.gradient_checkpointing = value

+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        self.attention.set_use_xformers_or_sdpa(xformers, sdpa)
+
    def forward_body(self, x, kv):
        kv = self.kv_mapper(kv)
        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
@@ -657,6 +775,12 @@ class StageB(nn.Module):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        for block in self.down_blocks + self.up_blocks:
+            for layer in block:
+                if hasattr(layer, "set_use_xformers_or_sdpa"):
+                    layer.set_use_xformers_or_sdpa(xformers, sdpa)
+
    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
@@ -920,6 +1044,12 @@ class StageC(nn.Module):
                if hasattr(layer, "set_gradient_checkpointing"):
                    layer.set_gradient_checkpointing(value)

+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        for block in self.down_blocks + self.up_blocks:
+            for layer in block:
+                if hasattr(layer, "set_use_xformers_or_sdpa"):
+                    layer.set_use_xformers_or_sdpa(xformers, sdpa)
+
    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
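StageB and StageC share this propagation idiom: walk every layer of every block and forward the flag pair to anything that defines set_use_xformers_or_sdpa, so non-attention layers are skipped via hasattr and new block types opt in automatically. A minimal sketch of the pattern, assuming Attention2D above is in scope (TinyStage and its block layout are illustrative, not from the commit):

    import torch.nn as nn

    class TinyStage(nn.Module):
        def __init__(self):
            super().__init__()
            # stand-ins for the real down/up blocks; only Attention2D opts in
            self.down_blocks = nn.ModuleList([nn.ModuleList([Attention2D(64, 4), nn.Identity()])])
            self.up_blocks = nn.ModuleList([nn.ModuleList([Attention2D(64, 4)])])

        def set_use_xformers_or_sdpa(self, xformers, sdpa):
            for block in list(self.down_blocks) + list(self.up_blocks):
                for layer in block:
                    if hasattr(layer, "set_use_xformers_or_sdpa"):
                        layer.set_use_xformers_or_sdpa(xformers, sdpa)

    TinyStage().set_use_xformers_or_sdpa(False, True)  # route all attention through SDPA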

View File

@@ -103,6 +103,9 @@ def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.
    generator_c = sc.StageC()
    logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
    stage_c_checkpoint = load_file(stage_c_checkpoint_path)
+    stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
+
    logger.info(f"Loading state dict")
    info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
    logger.info(info)
@@ -115,6 +118,9 @@ def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.
    generator_b = sc.StageB()
    logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
    stage_b_checkpoint = load_file(stage_b_checkpoint_path)
+    stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
+
    logger.info(f"Loading state dict")
    info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
    logger.info(info)
@@ -189,6 +195,55 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
    return previewer


+def convert_state_dict_mha_to_normal_attn(state_dict):
+    # convert nn.MultiheadAttention to to_q/k/v and to_out
+    print("convert_state_dict_mha_to_normal_attn")
+    for key in list(state_dict.keys()):
+        if "attention.attn." in key:
+            if "in_proj_bias" in key:
+                value = state_dict.pop(key)
+                qkv = torch.chunk(value, 3, dim=0)
+                state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
+                state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
+                state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
+            elif "in_proj_weight" in key:
+                value = state_dict.pop(key)
+                qkv = torch.chunk(value, 3, dim=0)
+                state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
+                state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
+                state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
+            elif "out_proj.bias" in key:
+                value = state_dict.pop(key)
+                state_dict[key.replace("out_proj.bias", "to_out.bias")] = value
+            elif "out_proj.weight" in key:
+                value = state_dict.pop(key)
+                state_dict[key.replace("out_proj.weight", "to_out.weight")] = value
+    return state_dict
+
+
+def convert_state_dict_normal_attn_to_mha(state_dict):
+    # convert to_q/k/v and to_out to nn.MultiheadAttention
+    for key in list(state_dict.keys()):
+        if "attention.attn." in key:
+            if "to_q.bias" in key:
+                q = state_dict.pop(key)
+                k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
+                v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
+                state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
+            elif "to_q.weight" in key:
+                q = state_dict.pop(key)
+                k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
+                v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
+                state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
+            elif "to_out.bias" in key:
+                v = state_dict.pop(key)
+                state_dict[key.replace("to_out.bias", "out_proj.bias")] = v
+            elif "to_out.weight" in key:
+                v = state_dict.pop(key)
+                state_dict[key.replace("to_out.weight", "out_proj.weight")] = v
+    return state_dict


def get_sai_model_spec(args, lora=False):
    timestamp = time.time()
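The two converters are exact inverses: torch.chunk along dim 0 followed by torch.cat is lossless, so a checkpoint can round-trip between the nn.MultiheadAttention layout (on disk) and the to_q/k/v layout (in memory) without changing any values. A quick sketch, assuming both functions above are importable (the module path below is hypothetical; only the "attention.attn." substring matters to the filter):

    import torch

    c = 64  # illustrative width
    prefix = "down_blocks.0.attention.attn"  # hypothetical key path
    sd = {
        f"{prefix}.in_proj_weight": torch.randn(3 * c, c),
        f"{prefix}.in_proj_bias": torch.randn(3 * c),
        f"{prefix}.out_proj.weight": torch.randn(c, c),
        f"{prefix}.out_proj.bias": torch.randn(c),
    }
    converted = convert_state_dict_mha_to_normal_attn(dict(sd))
    assert f"{prefix}.to_q.weight" in converted and f"{prefix}.in_proj_weight" not in converted
    restored = convert_state_dict_normal_attn_to_mha(dict(converted))
    assert sorted(restored) == sorted(sd) and all(torch.equal(sd[k], restored[k]) for k in sd)

This is why the loaders above call convert_state_dict_mha_to_normal_attn on load while the saver below converts back before save_file: checkpoints on disk stay compatible with the original Stable Cascade layout.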
@@ -230,6 +285,8 @@ def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadat
    if save_dtype is not None:
        state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}

+    state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
+
    save_file(state_dict, ckpt_file, metadata=sai_metadata)

    # save text model

View File

@@ -29,7 +29,7 @@ def main(file):
    for key, value in values:
        value = value.to(torch.float32)
-        logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
+        print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")


def setup_parser() -> argparse.ArgumentParser:

View File

@@ -40,9 +40,15 @@ def main(args):
    generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
    generator_c.eval().requires_grad_(False).to(loading_device)
+    # if args.xformers or args.sdpa:
+    print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
+    generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)

    generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
    generator_b.eval().requires_grad_(False).to(loading_device)
+    # if args.xformers or args.sdpa:
+    print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
+    generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)

    # CLIP encoders
    tokenizer = sc_utils.load_tokenizer(args)
@@ -332,6 +338,8 @@ if __name__ == "__main__":
    sc_utils.add_text_model_arguments(parser)
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--fp16", action="store_true")
+    parser.add_argument("--xformers", action="store_true")
+    parser.add_argument("--sdpa", action="store_true")
    parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
    parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
    parser.add_argument(
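The two new flags are independent booleans rather than a single choice; since Attention.forward tests the xformers branch before the sdpa branch, the effective selection (inferred from that forward, not stated in the commit) is:

    # (args.xformers, args.sdpa) -> branch taken in Attention.forward
    #   (True,  True)  -> xformers (checked first)
    #   (True,  False) -> xformers
    #   (False, True)  -> sdpa
    #   (False, False) -> the plain baddbmm/bmm fallback in _attention()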

View File

@@ -295,10 +295,9 @@ class NetworkTrainer:
        # text_encoder is List[CLIPTextModel] or CLIPTextModel
        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

-        # # モデルに xformers とか memory efficient attention を組み込む
+        # incorporate xformers or memory efficient attention into the model / モデルに xformers とか memory efficient attention を組み込む
        # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        # if torch.__version__ >= "2.0.0":  # usable with xformers built for PyTorch 2.0.0+ / PyTorch 2.0.0 以上対応のxformersなら以下が使える
-        # vae.set_use_memory_efficient_attention_xformers(args.xformers)
+        stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)

        # load models for differential additional training / 差分追加学習のためにモデルを読み込む
        sys.path.append(os.path.dirname(__file__))

View File

@@ -147,6 +147,9 @@ def train(args):
    else:
        previewer = None

+    # incorporate xformers or memory efficient attention into the model / モデルに xformers とか memory efficient attention を組み込む
+    stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
+
    # prepare for training / 学習を準備する
    if cache_latents:
        effnet.to(accelerator.device, dtype=effnet_dtype)