diff --git a/gen_img.py b/gen_img.py
index a24220a0..daf88d2a 100644
--- a/gen_img.py
+++ b/gen_img.py
@@ -20,7 +20,7 @@ import diffusers
 import numpy as np
 import torch
 
-from library.ipex_interop import init_ipex
+from library.device_utils import init_ipex, clean_memory, get_preferred_device
 
 init_ipex()
 
@@ -338,7 +338,7 @@ class PipelineLike:
         self.clip_vision_model: CLIPVisionModelWithProjection = None
         self.clip_vision_processor: CLIPImageProcessor = None
         self.clip_vision_strength = 0.0
-
+
         # Textual Inversion
         self.token_replacements_list = []
         for _ in range(len(self.text_encoders)):
diff --git a/library/stable_cascade.py b/library/stable_cascade.py
index 624bf7ce..ff51966d 100644
--- a/library/stable_cascade.py
+++ b/library/stable_cascade.py
@@ -5,6 +5,7 @@ import math
 from types import SimpleNamespace
 from typing import List, Optional
 
+from einops import rearrange
 import numpy as np
 import torch
 import torch.nn as nn
@@ -148,7 +149,7 @@ class EfficientNetEncoder(nn.Module):
         The method to make it usable like VAE. It should be separated properly, but it is a temporary response.
         """
         # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
-
+
         # x is -1 to 1, so we need to convert it to 0 to 1, and then preprocess it with EfficientNet's preprocessing.
         x = (x + 1) / 2
         x = EFFNET_PREPROCESS(x)
@@ -172,6 +173,7 @@ from torch.nn import Conv2d
 from torch.nn import Linear
 
 
+r"""
 class Attention2D(nn.Module):
     def __init__(self, c, nhead, dropout=0.0):
         super().__init__()
@@ -185,6 +187,119 @@ class Attention2D(nn.Module):
         x = self.attn(x, kv, kv, need_weights=False)[0]
         x = x.permute(0, 2, 1).view(*orig_shape)
         return x
+"""
+
+
+class Attention(nn.Module):
+    def __init__(self, c, nhead, dropout=0.0):
+        # dropout is for attn_output_weights, so we may not need it. however, if we use sdpa, we enable it.
+        # xformers and normal attn are not affected by dropout
+        super().__init__()
+
+        self.to_q = Linear(c, c, bias=True)
+        self.to_k = Linear(c, c, bias=True)
+        self.to_v = Linear(c, c, bias=True)
+        self.to_out = Linear(c, c, bias=True)
+        self.nhead = nhead
+        self.dropout = dropout
+        self.scale = (c // nhead) ** -0.5
+
+        # default is to use sdpa
+        self.use_memory_efficient_attention_xformers = False
+        self.use_sdpa = True
+
+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        # print(f"Attention: set_use_xformers_or_sdpa: xformers={xformers}, sdpa={sdpa}")
+        self.use_memory_efficient_attention_xformers = xformers
+        self.use_sdpa = sdpa
+
+    def forward(self, q_in, k_in, v_in):
+        q_in = self.to_q(q_in)
+        k_in = self.to_k(k_in)
+        v_in = self.to_v(v_in)
+
+        if self.use_memory_efficient_attention_xformers:
+            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.nhead), (q_in, k_in, v_in))
+            del q_in, k_in, v_in
+            out = self.forward_memory_efficient_xformers(q, k, v)
+            del q, k, v
+            out = rearrange(out, "b n h d -> b n (h d)", h=self.nhead)
+        elif self.use_sdpa:
+            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.nhead), (q_in, k_in, v_in))
+            del q_in, k_in, v_in
+            out = self.forward_sdpa(q, k, v)
+            del q, k, v
+            out = rearrange(out, "b h n d -> b n (h d)", h=self.nhead)
+        else:
+            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=self.nhead), (q_in, k_in, v_in))
+            del q_in, k_in, v_in
+            out = self._attention(q, k, v)
+            del q, k, v
+            out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)
+
+        return self.to_out(out)
+
+    def _attention(self, query, key, value):
+        # if self.upcast_attention:
+        #     query = query.float()
+        #     key = key.float()
+
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
+        attention_probs = attention_scores.softmax(dim=-1)
+
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
+        hidden_states = torch.bmm(attention_probs, value)
+
+        return hidden_states
+
+    def forward_memory_efficient_xformers(self, q, k, v):
+        import xformers.ops
+
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # the optimal kernel is selected automatically
+        del q, k, v
+
+        return out
+
+    def forward_sdpa(self, q, k, v):
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=False)
+        return out
+
+
+class Attention2D(nn.Module):
+    r"""
+    modified to have separate weights for to_q/k/v
+    """
+
+    def __init__(self, c, nhead, dropout=0.0):
+        super().__init__()
+        # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
+        self.attn = Attention(c, nhead, dropout=dropout)  # , bias=True, batch_first=True)
+
+    def forward(self, x, kv, self_attn=False):
+        orig_shape = x.shape
+        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # Bx4xHxW -> Bx(HxW)x4
+        if self_attn:
+            kv = torch.cat([x, kv], dim=1)
+        # x = self.attn(x, kv, kv, need_weights=False)[0]
+        x = self.attn(x, kv, kv)
+        x = x.permute(0, 2, 1).view(*orig_shape)
+        return x
+
+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        self.attn.set_use_xformers_or_sdpa(xformers, sdpa)
 
 
 class LayerNorm2d(nn.LayerNorm):
@@ -262,6 +377,9 @@ class AttnBlock(nn.Module):
     def set_gradient_checkpointing(self, value):
         self.gradient_checkpointing = value
 
+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        self.attention.set_use_xformers_or_sdpa(xformers, sdpa)
+
     def forward_body(self, x, kv):
         kv = self.kv_mapper(kv)
         x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
@@ -657,6 +775,12 @@ class StageB(nn.Module):
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
 
+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        for block in self.down_blocks + self.up_blocks:
+            for layer in block:
+                if hasattr(layer, "set_use_xformers_or_sdpa"):
+                    layer.set_use_xformers_or_sdpa(xformers, sdpa)
+
     def gen_r_embedding(self, r, max_positions=10000):
         r = r * max_positions
         half_dim = self.c_r // 2
@@ -920,6 +1044,12 @@ class StageC(nn.Module):
                 if hasattr(layer, "set_gradient_checkpointing"):
                     layer.set_gradient_checkpointing(value)
 
+    def set_use_xformers_or_sdpa(self, xformers, sdpa):
+        for block in self.down_blocks + self.up_blocks:
+            for layer in block:
+                if hasattr(layer, "set_use_xformers_or_sdpa"):
+                    layer.set_use_xformers_or_sdpa(xformers, sdpa)
+
     def gen_r_embedding(self, r, max_positions=10000):
         r = r * max_positions
         half_dim = self.c_r // 2
diff --git a/library/stable_cascade_utils.py b/library/stable_cascade_utils.py
index 83bef254..571d44ed 100644
--- a/library/stable_cascade_utils.py
+++ b/library/stable_cascade_utils.py
@@ -103,6 +103,9 @@ def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.
     generator_c = sc.StageC()
     logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
     stage_c_checkpoint = load_file(stage_c_checkpoint_path)
+
+    stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
+
     logger.info(f"Loading state dict")
     info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
     logger.info(info)
@@ -115,6 +118,9 @@ def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.
     generator_b = sc.StageB()
     logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
     stage_b_checkpoint = load_file(stage_b_checkpoint_path)
+
+    stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
+
     logger.info(f"Loading state dict")
     info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
     logger.info(info)
@@ -189,6 +195,55 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
     return previewer
 
 
+def convert_state_dict_mha_to_normal_attn(state_dict):
+    # convert nn.MultiheadAttention to to_q/k/v and to_out
+    print("convert_state_dict_mha_to_normal_attn")
+    for key in list(state_dict.keys()):
+        if "attention.attn." in key:
+            if "in_proj_bias" in key:
+                value = state_dict.pop(key)
+                qkv = torch.chunk(value, 3, dim=0)
+                state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
+                state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
+                state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
+            elif "in_proj_weight" in key:
+                value = state_dict.pop(key)
+                qkv = torch.chunk(value, 3, dim=0)
+                state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
+                state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
+                state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
+            elif "out_proj.bias" in key:
+                value = state_dict.pop(key)
+                state_dict[key.replace("out_proj.bias", "to_out.bias")] = value
+            elif "out_proj.weight" in key:
+                value = state_dict.pop(key)
+                state_dict[key.replace("out_proj.weight", "to_out.weight")] = value
+    return state_dict
+
+
+def convert_state_dict_normal_attn_to_mha(state_dict):
+    # convert to_q/k/v and to_out to nn.MultiheadAttention
+    for key in list(state_dict.keys()):
+        if "attention.attn." in key:
+            if "to_q.bias" in key:
+                q = state_dict.pop(key)
+                k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
+                v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
+                state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
+            elif "to_q.weight" in key:
+                q = state_dict.pop(key)
+                k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
+                v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
+                state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
+            elif "to_out.bias" in key:
+                v = state_dict.pop(key)
+                state_dict[key.replace("to_out.bias", "out_proj.bias")] = v
+            elif "to_out.weight" in key:
+                v = state_dict.pop(key)
+                state_dict[key.replace("to_out.weight", "out_proj.weight")] = v
+    return state_dict
+
+
 def get_sai_model_spec(args, lora=False):
     timestamp = time.time()
 
@@ -230,6 +285,8 @@ def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadat
     if save_dtype is not None:
         state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
 
+    state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
+
     save_file(state_dict, ckpt_file, metadata=sai_metadata)
 
     # save text model
diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py
index 6ec60a89..2ee162cc 100644
--- a/networks/check_lora_weights.py
+++ b/networks/check_lora_weights.py
@@ -29,7 +29,7 @@ def main(file):
 
     for key, value in values:
         value = value.to(torch.float32)
-        logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
+        print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
 
 
 def setup_parser() -> argparse.ArgumentParser:
diff --git a/stable_cascade_gen_img.py b/stable_cascade_gen_img.py
index 76fe3b39..827941d0 100644
--- a/stable_cascade_gen_img.py
+++ b/stable_cascade_gen_img.py
@@ -40,9 +40,15 @@ def main(args):
     generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
     generator_c.eval().requires_grad_(False).to(loading_device)
+    # if args.xformers or args.sdpa:
+    print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
+    generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
 
     generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
     generator_b.eval().requires_grad_(False).to(loading_device)
+    # if args.xformers or args.sdpa:
+    print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
+    generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
 
     # CLIP encoders
     tokenizer = sc_utils.load_tokenizer(args)
@@ -332,6 +338,8 @@ if __name__ == "__main__":
     sc_utils.add_text_model_arguments(parser)
     parser.add_argument("--bf16", action="store_true")
     parser.add_argument("--fp16", action="store_true")
+    parser.add_argument("--xformers", action="store_true")
+    parser.add_argument("--sdpa", action="store_true")
     parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
     parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
     parser.add_argument(
diff --git a/stable_cascade_train_c_network.py b/stable_cascade_train_c_network.py
index f7efc60c..b7f4c3a9 100644
--- a/stable_cascade_train_c_network.py
+++ b/stable_cascade_train_c_network.py
@@ -295,10 +295,9 @@ class NetworkTrainer:
         # text_encoder is List[CLIPTextModel] or CLIPTextModel
         text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
 
-        # # モデルに xformers とか memory efficient attention を組み込む
+        # incorporate xformers or memory efficient attention into the model
        # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        # if torch.__version__ >= "2.0.0":  # PyTorch 2.0.0 以上対応のxformersなら以下が使える
-        #     vae.set_use_memory_efficient_attention_xformers(args.xformers)
+        stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
 
         # 差分追加学習のためにモデルを読み込む
         sys.path.append(os.path.dirname(__file__))
diff --git a/stable_cascade_train_stage_c.py b/stable_cascade_train_stage_c.py
index de3bfed8..fa9ed2f4 100644
--- a/stable_cascade_train_stage_c.py
+++ b/stable_cascade_train_stage_c.py
@@ -147,6 +147,9 @@ def train(args):
     else:
         previewer = None
 
+    # incorporate xformers or memory efficient attention into the model
+    stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
+
     # 学習を準備する
     if cache_latents:
         effnet.to(accelerator.device, dtype=effnet_dtype)
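
A minimal usage sketch of the state-dict converters introduced above (illustrative, not part of the diff). nn.MultiheadAttention packs the query/key/value projections into a single in_proj_weight/in_proj_bias and keeps the output projection as out_proj, while the new Attention module expects separate to_q/to_k/to_v/to_out, so loading splits the packed tensors with torch.chunk and saving concatenates them back with torch.cat; only keys containing "attention.attn." are touched. The key prefix and embedding dimension below are hypothetical, and the snippet assumes the repository root is on PYTHONPATH.

import torch

from library.stable_cascade_utils import (
    convert_state_dict_mha_to_normal_attn,
    convert_state_dict_normal_attn_to_mha,
)

c = 8  # illustrative embedding dim; real checkpoints use the model's hidden size
mha_sd = {
    # hypothetical key prefix; any key containing "attention.attn." is converted
    "blocks.0.attention.attn.in_proj_weight": torch.randn(3 * c, c),  # packed q/k/v
    "blocks.0.attention.attn.in_proj_bias": torch.randn(3 * c),
    "blocks.0.attention.attn.out_proj.weight": torch.randn(c, c),
    "blocks.0.attention.attn.out_proj.bias": torch.randn(c),
}

# what the loaders do: split packed projections into to_q/to_k/to_v and rename out_proj -> to_out
# (a shallow copy is passed because the converters mutate the dict in place)
normal_sd = convert_state_dict_mha_to_normal_attn(dict(mha_sd))
assert normal_sd["blocks.0.attention.attn.to_q.weight"].shape == (c, c)

# what the savers do: concatenate the projections back into the nn.MultiheadAttention layout
restored = convert_state_dict_normal_attn_to_mha(dict(normal_sd))
for key, value in mha_sd.items():
    assert torch.equal(restored[key], value)  # the conversion round-trips losslessly

With the flags wired through set_use_xformers_or_sdpa, passing --xformers or --sdpa to stable_cascade_gen_img.py (and to the Stage C training scripts) selects the attention backend; if neither is given, the plain baddbmm path in Attention._attention is used.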