Merge pull request #961 from rockerBOO/attention-processor

Add attention processor
This commit is contained in:
Kohya S
2023-12-03 21:24:12 +09:00
committed by GitHub

View File

@@ -586,6 +586,9 @@ class CrossAttention(nn.Module):
self.use_memory_efficient_attention_mem_eff = False self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False self.use_sdpa = False
# Attention processor
self.processor = None
def set_use_memory_efficient_attention(self, xformers, mem_eff): def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -607,7 +610,28 @@ class CrossAttention(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor return tensor
def forward(self, hidden_states, context=None, mask=None): def set_processor(self):
return self.processor
def get_processor(self):
return self.processor
def forward(self, hidden_states, context=None, mask=None, **kwargs):
if self.processor is not None:
(
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
)
if self.use_memory_efficient_attention_xformers: if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask) return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff: if self.use_memory_efficient_attention_mem_eff:
@@ -720,6 +744,21 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out) out = self.to_out[0](out)
return out return out
def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states
# translate from hugging face diffusers
mask = mask if mask is not None else attention_mask
return hidden_states, context, mask
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
@@ -1350,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):
self.out_channels = OUT_CHANNELS self.out_channels = OUT_CHANNELS
self.sample_size = sample_size self.sample_size = sample_size
self.prepare_config() self.prepare_config(sample_size=sample_size)
# state_dictの書式が変わるのでmoduleの持ち方は変えられない # state_dictの書式が変わるのでmoduleの持ち方は変えられない
@@ -1437,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# region diffusers compatibility # region diffusers compatibility
def prepare_config(self): def prepare_config(self, *args, **kwargs):
self.config = SimpleNamespace() self.config = SimpleNamespace(**kwargs)
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype: