Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)

Commit: fix eps value, enable xformers, etc.
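In brief, the commit does two things: it fixes the U-Net normalization epsilon (RESNET_EPS = 1e-6 is renamed and corrected to NORM_EPS = 1e-5), and it enables xformers through a per-module toggle on the custom U-Net in library.original_unet instead of monkey-patching diffusers.models.attention.CrossAttention.forward. The following is a minimal, self-contained sketch of that toggle-and-dispatch pattern; it is an illustration only, not code from the commit, and TinyCrossAttention is a made-up stand-in for the repository's CrossAttention class.

import torch
from torch import nn


class TinyCrossAttention(nn.Module):
    # Illustration only: stand-in for library.original_unet.CrossAttention.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.use_memory_efficient_attention_xformers = False

    def set_use_memory_efficient_attention_xformers(self, value):
        self.use_memory_efficient_attention_xformers = value

    def forward(self, x):
        if self.use_memory_efficient_attention_xformers:
            # the real module dispatches to forward_memory_efficient_xformers(),
            # which calls xformers.ops.memory_efficient_attention
            return self.forward_memory_efficient_xformers(x)
        return self.proj(x)

    def forward_memory_efficient_xformers(self, x):
        return self.proj(x)  # placeholder for the xformers path


attn = TinyCrossAttention(8)
attn.set_use_memory_efficient_attention_xformers(True)
print(attn(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])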
@@ -317,7 +317,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
     if mem_eff_attn:
         replace_unet_cross_attn_to_memory_efficient()
     elif xformers:
-        replace_unet_cross_attn_to_xformers()
+        replace_unet_cross_attn_to_xformers(unet)


 def replace_unet_cross_attn_to_memory_efficient():
@@ -357,50 +357,55 @@ def replace_unet_cross_attn_to_memory_efficient():
         out = self.to_out[1](out)
         return out

-    diffusers.models.attention.CrossAttention.forward = forward_flash_attn
+    # diffusers.models.attention.CrossAttention.forward = forward_flash_attn
+    # TODO U-Net側に移す
+    from library.original_unet import CrossAttention
+    CrossAttention.forward = forward_flash_attn


-def replace_unet_cross_attn_to_xformers():
+def replace_unet_cross_attn_to_xformers(unet:UNet2DConditionModel):
     print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
     try:
         import xformers.ops
     except ImportError:
         raise ImportError("No xformers / xformersがインストールされていないようです")

-    def forward_xformers(self, x, context=None, mask=None):
-        h = self.heads
-        q_in = self.to_q(x)
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    # def forward_xformers(self, x, context=None, mask=None):
+    #     h = self.heads
+    #     q_in = self.to_q(x)

-        context = default(context, x)
-        context = context.to(x.dtype)
+    #     context = default(context, x)
+    #     context = context.to(x.dtype)

-        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-            context_k, context_v = self.hypernetwork.forward(x, context)
-            context_k = context_k.to(x.dtype)
-            context_v = context_v.to(x.dtype)
-        else:
-            context_k = context
-            context_v = context
+    #     if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
+    #         context_k, context_v = self.hypernetwork.forward(x, context)
+    #         context_k = context_k.to(x.dtype)
+    #         context_v = context_v.to(x.dtype)
+    #     else:
+    #         context_k = context
+    #         context_v = context

-        k_in = self.to_k(context_k)
-        v_in = self.to_v(context_v)
+    #     k_in = self.to_k(context_k)
+    #     v_in = self.to_v(context_v)

-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
+    #     q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
+    #     del q_in, k_in, v_in

-        q = q.contiguous()
-        k = k.contiguous()
-        v = v.contiguous()
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
+    #     q = q.contiguous()
+    #     k = k.contiguous()
+    #     v = v.contiguous()
+    #     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる

-        out = rearrange(out, "b n h d -> b n (h d)", h=h)
+    #     out = rearrange(out, "b n h d -> b n (h d)", h=h)

-        # diffusers 0.7.0~
-        out = self.to_out[0](out)
-        out = self.to_out[1](out)
-        return out
+    #     # diffusers 0.7.0~
+    #     out = self.to_out[0](out)
+    #     out = self.to_out[1](out)
+    #     return out

-    diffusers.models.attention.CrossAttention.forward = forward_xformers
+    # diffusers.models.attention.CrossAttention.forward = forward_xformers


 def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
@@ -1,10 +1,10 @@
 # Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる
+# 条件分岐等で不要な部分は削除している
 # コードの多くはDiffusersからコピーしている
-# コードが冗長になる部分はコメント等を適宜削除する
 # 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある

 # Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
-# Remove redundant code by deleting comments, etc. as appropriate
+# Unnecessary parts are deleted by condition branching.
 # As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2

 """
@@ -111,6 +111,7 @@ from typing import Dict, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import functional as F
+from einops import rearrange

 BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
 TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
@@ -121,8 +122,8 @@ LAYERS_PER_BLOCK: int = 2
 LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
 TIME_EMBED_FLIP_SIN_TO_COS: bool = True
 TIME_EMBED_FREQ_SHIFT: int = 0
-RESNET_GROUPS: int = 32
-RESNET_EPS: float = 1e-6
+NORM_GROUPS: int = 32
+NORM_EPS: float = 1e-5
 TRANSFORMER_NORM_NUM_GROUPS = 32

 DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
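The renamed constants above are what the "fix eps value" part of the commit title refers to. A small self-contained illustration (not from the commit) of the GroupNorm layers they configure, now built with eps=1e-5 instead of 1e-6:

import torch
from torch import nn

NORM_GROUPS = 32  # formerly RESNET_GROUPS
NORM_EPS = 1e-5   # formerly RESNET_EPS = 1e-6

norm = nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=320, eps=NORM_EPS, affine=True)
print(norm(torch.randn(1, 320, 8, 8)).shape)  # torch.Size([1, 320, 8, 8])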
@@ -233,13 +234,13 @@ class ResnetBlock2D(nn.Module):
         self.in_channels = in_channels
         self.out_channels = out_channels

-        self.norm1 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=in_channels, eps=RESNET_EPS, affine=True)
+        self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)

         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

         self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)

-        self.norm2 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=out_channels, eps=RESNET_EPS, affine=True)
+        self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

         # if non_linearity == "swish":
@@ -304,6 +305,9 @@ class DownBlock2D(nn.Module):

         self.gradient_checkpointing = False

+    def set_use_memory_efficient_attention_xformers(self, value):
+        pass
+
     def forward(self, hidden_states, temb=None):
         output_states = ()

@@ -372,6 +376,11 @@ class CrossAttention(nn.Module):
         self.to_out.append(nn.Linear(inner_dim, query_dim))
         # no dropout here

+        self.use_memory_efficient_attention_xformers = False
+
+    def set_use_memory_efficient_attention_xformers(self, value):
+        self.use_memory_efficient_attention_xformers = value
+
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
@@ -387,6 +396,9 @@ class CrossAttention(nn.Module):
         return tensor

     def forward(self, hidden_states, context=None, mask=None):
+        if self.use_memory_efficient_attention_xformers:
+            return self.forward_memory_efficient_xformers(hidden_states, context, mask)
+
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
         key = self.to_k(context)
@@ -427,6 +439,30 @@ class CrossAttention(nn.Module):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states

+    # TODO support Hypernetworks
+    def forward_memory_efficient_xformers(self, x, context=None, mask=None):
+        import xformers.ops
+
+        h = self.heads
+        q_in = self.to_q(x)
+        context = context if context is not None else x
+        context = context.to(x.dtype)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
+
+        out = rearrange(out, "b n h d -> b n (h d)", h=h)
+
+        out = self.to_out[0](out)
+        return out
+

 # feedforward
 class GEGLU(nn.Module):
@@ -506,8 +542,9 @@ class BasicTransformerBlock(nn.Module):
         # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim)

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        raise NotImplementedError("Memory efficient attention is not implemented for this model.")
+    def set_use_memory_efficient_attention_xformers(self, value: bool):
+        self.attn1.set_use_memory_efficient_attention_xformers(value)
+        self.attn2.set_use_memory_efficient_attention_xformers(value)

     def forward(self, hidden_states, context=None, timestep=None):
         # 1. Self-Attention
@@ -566,6 +603,10 @@ class Transformer2DModel(nn.Module):
         else:
             self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

+    def set_use_memory_efficient_attention_xformers(self, value):
+        for transformer in self.transformer_blocks:
+            transformer.set_use_memory_efficient_attention_xformers(value)
+
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
         # 1. Input
         batch, _, height, weight = hidden_states.shape
@@ -643,6 +684,10 @@ class CrossAttnDownBlock2D(nn.Module):

         self.gradient_checkpointing = False

+    def set_use_memory_efficient_attention_xformers(self, value):
+        for attn in self.attentions:
+            attn.set_use_memory_efficient_attention_xformers(value)
+
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()

@@ -714,11 +759,37 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
+    def set_use_memory_efficient_attention_xformers(self, value):
+        for attn in self.attentions:
+            attn.set_use_memory_efficient_attention_xformers(value)
+
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
-        hidden_states = self.resnets[0](hidden_states, temb)
-        for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(hidden_states, encoder_hidden_states).sample
-            hidden_states = resnet(hidden_states, temb)
+        for i, resnet in enumerate(self.resnets):
+            attn = None if i == 0 else self.attentions[i - 1]
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                if attn is not None:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+                    )[0]
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, encoder_hidden_states).sample
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

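The rewritten UNetMidBlock2DCrossAttn.forward above wraps its attention and resnet submodules with torch.utils.checkpoint when gradient_checkpointing is enabled during training. A small self-contained illustration of that create_custom_forward / checkpoint pattern (not code from the commit; an nn.Linear stands in for the real submodules):

import torch
import torch.utils.checkpoint
from torch import nn

layer = nn.Linear(16, 16)


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


x = torch.randn(4, 16, requires_grad=True)
# activations are recomputed in the backward pass instead of being stored
y = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])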
@@ -792,6 +863,9 @@ class UpBlock2D(nn.Module):

         self.gradient_checkpointing = False

+    def set_use_memory_efficient_attention_xformers(self, value):
+        pass
+
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
         for resnet in self.resnets:
             # pop res hidden states
@@ -868,6 +942,10 @@ class CrossAttnUpBlock2D(nn.Module):

         self.gradient_checkpointing = False

+    def set_use_memory_efficient_attention_xformers(self, value):
+        for attn in self.attentions:
+            attn.set_use_memory_efficient_attention_xformers(value)
+
     def forward(
         self,
         hidden_states,
@@ -991,6 +1069,8 @@ class UNet2DConditionModel(nn.Module):

         self.sample_size = sample_size

+        # state_dictの書式が変わるのでmoduleの持ち方は変えられない
+
         # input
         self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))

@@ -1069,7 +1149,7 @@ class UNet2DConditionModel(nn.Module):
             prev_output_channel = output_channel

         # out
-        self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=RESNET_GROUPS, eps=RESNET_EPS)
+        self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)

@@ -1088,16 +1168,20 @@ class UNet2DConditionModel(nn.Module):
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

     def enable_gradient_checkpointing(self):
-        self._set_gradient_checkpointing(self, value=True)
+        self.set_gradient_checkpointing(value=True)

     def disable_gradient_checkpointing(self):
-        self._set_gradient_checkpointing(self, value=False)
+        self.set_gradient_checkpointing(value=False)

     def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
-        raise NotImplementedError("Memory efficient attention is not supported for this model.")
+        modules = self.down_blocks + [self.mid_block] + self.up_blocks
+        for module in modules:
+            module.set_use_memory_efficient_attention_xformers(valid)

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+    def set_gradient_checkpointing(self, value=False):
+        modules = self.down_blocks + [self.mid_block] + self.up_blocks
+        for module in modules:
+            print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
             module.gradient_checkpointing = value

 # endregion
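At the top level, the new set_gradient_checkpointing and set_use_memory_efficient_attention_xformers methods simply fan the setting out to down_blocks, mid_block and up_blocks. A compact illustration of the gradient-checkpointing side of that delegation (illustration only; TinyUNet and Block are made-up stand-ins, not the repository's classes):

from torch import nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.down_blocks = nn.ModuleList([Block(), Block()])
        self.mid_block = Block()
        self.up_blocks = nn.ModuleList([Block(), Block()])

    def set_gradient_checkpointing(self, value=False):
        for module in list(self.down_blocks) + [self.mid_block] + list(self.up_blocks):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        self.set_gradient_checkpointing(value=True)

    def disable_gradient_checkpointing(self):
        self.set_gradient_checkpointing(value=False)


unet = TinyUNet()
unet.enable_gradient_checkpointing()
print(all(b.gradient_checkpointing for b in unet.down_blocks))  # True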
@@ -1792,7 +1792,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
     if mem_eff_attn:
         replace_unet_cross_attn_to_memory_efficient()
     elif xformers:
-        replace_unet_cross_attn_to_xformers()
+        replace_unet_cross_attn_to_xformers(unet)


 def replace_unet_cross_attn_to_memory_efficient():
@@ -1827,55 +1827,59 @@ def replace_unet_cross_attn_to_memory_efficient():

         out = rearrange(out, "b h n d -> b n (h d)")

-        # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
         out = self.to_out[0](out)
-        out = self.to_out[1](out)
+        # out = self.to_out[1](out)
         return out

-    diffusers.models.attention.CrossAttention.forward = forward_flash_attn
+    # diffusers.models.attention.CrossAttention.forward = forward_flash_attn
+    from library.original_unet import CrossAttention
+
+    CrossAttention.forward = forward_flash_attn


-def replace_unet_cross_attn_to_xformers():
+def replace_unet_cross_attn_to_xformers(unet):
     print("CrossAttention.forward has been replaced to enable xformers.")
     try:
         import xformers.ops
     except ImportError:
         raise ImportError("No xformers / xformersがインストールされていないようです")

-    def forward_xformers(self, x, context=None, mask=None):
-        h = self.heads
-        q_in = self.to_q(x)
+    unet.set_use_memory_efficient_attention_xformers(True)

-        context = default(context, x)
-        context = context.to(x.dtype)
+    # def forward_xformers(self, x, context=None, mask=None):
+    #     h = self.heads
+    #     q_in = self.to_q(x)

-        if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
-            context_k, context_v = self.hypernetwork.forward(x, context)
-            context_k = context_k.to(x.dtype)
-            context_v = context_v.to(x.dtype)
-        else:
-            context_k = context
-            context_v = context
+    #     context = default(context, x)
+    #     context = context.to(x.dtype)

-        k_in = self.to_k(context_k)
-        v_in = self.to_v(context_v)
+    #     if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
+    #         context_k, context_v = self.hypernetwork.forward(x, context)
+    #         context_k = context_k.to(x.dtype)
+    #         context_v = context_v.to(x.dtype)
+    #     else:
+    #         context_k = context
+    #         context_v = context

-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
+    #     k_in = self.to_k(context_k)
+    #     v_in = self.to_v(context_v)

-        q = q.contiguous()
-        k = k.contiguous()
-        v = v.contiguous()
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
+    #     q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
+    #     del q_in, k_in, v_in

-        out = rearrange(out, "b n h d -> b n (h d)", h=h)
+    #     q = q.contiguous()
+    #     k = k.contiguous()
+    #     v = v.contiguous()
+    #     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる

-        # diffusers 0.7.0~
-        out = self.to_out[0](out)
-        out = self.to_out[1](out)
-        return out
+    #     out = rearrange(out, "b n h d -> b n (h d)", h=h)

-    diffusers.models.attention.CrossAttention.forward = forward_xformers
+    #     # diffusers 0.7.0~
+    #     out = self.to_out[0](out)
+    #     out = self.to_out[1](out)
+    #     return out

+    # diffusers.models.attention.CrossAttention.forward = forward_xformers
+

 """