Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
Merge pull request #961 from rockerBOO/attention-processor
Add attention processor
@@ -586,6 +586,9 @@ class CrossAttention(nn.Module):
         self.use_memory_efficient_attention_mem_eff = False
         self.use_sdpa = False
 
+        # Attention processor
+        self.processor = None
+
     def set_use_memory_efficient_attention(self, xformers, mem_eff):
         self.use_memory_efficient_attention_xformers = xformers
         self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -607,7 +610,28 @@ class CrossAttention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def forward(self, hidden_states, context=None, mask=None):
+    def set_processor(self):
+        return self.processor
+
+    def get_processor(self):
+        return self.processor
+
+    def forward(self, hidden_states, context=None, mask=None, **kwargs):
+        if self.processor is not None:
+            (
+                hidden_states,
+                encoder_hidden_states,
+                attention_mask,
+            ) = translate_attention_names_from_diffusers(
+                hidden_states=hidden_states, context=context, mask=mask, **kwargs
+            )
+            return self.processor(
+                attn=self,
+                hidden_states=hidden_states,
+                encoder_hidden_states=context,
+                attention_mask=mask,
+                **kwargs
+            )
         if self.use_memory_efficient_attention_xformers:
            return self.forward_memory_efficient_xformers(hidden_states, context, mask)
         if self.use_memory_efficient_attention_mem_eff:
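
With this hunk, assigning any callable to processor routes the whole attention computation through it. The class below is a minimal sketch of such a callable, not part of the PR: it reimplements plain scaled-dot-product attention against the hook's calling convention, processor(attn=self, hidden_states=..., encoder_hidden_states=..., attention_mask=..., **kwargs). The attribute names to_q, to_k, to_v, scale, and reshape_heads_to_batch_dim are assumed from the surrounding CrossAttention code; only to_out and reshape_batch_dim_to_heads are visible in this diff.

    import torch

    class PlainAttnProcessor:
        def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
            # self-attention when no conditioning is passed
            context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

            # project, then fold heads into the batch dimension (helpers assumed from CrossAttention)
            q = attn.reshape_heads_to_batch_dim(attn.to_q(hidden_states))
            k = attn.reshape_heads_to_batch_dim(attn.to_k(context))
            v = attn.reshape_heads_to_batch_dim(attn.to_v(context))

            # scaled dot-product attention
            scores = torch.bmm(q, k.transpose(-1, -2)) * attn.scale
            if attention_mask is not None:
                scores = scores + attention_mask
            probs = scores.softmax(dim=-1)

            # merge heads back and apply the output projection
            out = attn.reshape_batch_dim_to_heads(torch.bmm(probs, v))
            return attn.to_out[0](out)

As merged, set_processor() takes no processor argument (it mirrors get_processor()), so installing a processor amounts to assigning the attribute directly: attn_module.processor = PlainAttnProcessor().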
@@ -720,6 +744,21 @@ class CrossAttention(nn.Module):
         out = self.to_out[0](out)
         return out
 
+def translate_attention_names_from_diffusers(
+    hidden_states: torch.FloatTensor,
+    context: Optional[torch.FloatTensor] = None,
+    mask: Optional[torch.FloatTensor] = None,
+    # HF naming
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.FloatTensor] = None
+):
+    # translate from hugging face diffusers
+    context = context if context is not None else encoder_hidden_states
+
+    # translate from hugging face diffusers
+    mask = mask if mask is not None else attention_mask
+
+    return hidden_states, context, mask
 
 # feedforward
 class GEGLU(nn.Module):
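
A hypothetical call showing the helper's fall-through: the diffusers-style keywords populate the sd-scripts names only when the latter are not given. The tensor shapes here are illustrative only.

    import torch

    x = torch.randn(2, 77, 320)     # hidden_states
    cond = torch.randn(2, 77, 768)  # diffusers-style conditioning

    hidden, context, mask = translate_attention_names_from_diffusers(
        hidden_states=x, encoder_hidden_states=cond, attention_mask=None
    )
    assert context is cond and mask is None  # encoder_hidden_states fell through to context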
@@ -1350,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):
         self.out_channels = OUT_CHANNELS
 
         self.sample_size = sample_size
-        self.prepare_config()
+        self.prepare_config(sample_size=sample_size)
 
         # the state_dict format would change, so the way modules are held cannot be changed
 
@@ -1437,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
         self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
 
     # region diffusers compatibility
-    def prepare_config(self):
-        self.config = SimpleNamespace()
+    def prepare_config(self, *args, **kwargs):
+        self.config = SimpleNamespace(**kwargs)
 
     @property
     def dtype(self) -> torch.dtype:
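
Taken together, the two prepare_config hunks make unet.config expose the constructor arguments that diffusers-style callers read back, instead of an empty namespace. A standalone sketch of the effect, assuming nothing beyond what the hunks show:

    from types import SimpleNamespace

    # new behaviour: keyword arguments become config attributes
    config = SimpleNamespace(sample_size=64)
    print(config.sample_size)  # 64

    # old behaviour: SimpleNamespace() was empty, so an attribute
    # lookup like config.sample_size raised AttributeError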