mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #961 from rockerBOO/attention-processor
Add attention processor
This commit is contained in:
@@ -586,6 +586,9 @@ class CrossAttention(nn.Module):
|
|||||||
self.use_memory_efficient_attention_mem_eff = False
|
self.use_memory_efficient_attention_mem_eff = False
|
||||||
self.use_sdpa = False
|
self.use_sdpa = False
|
||||||
|
|
||||||
|
# Attention processor
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||||
self.use_memory_efficient_attention_xformers = xformers
|
self.use_memory_efficient_attention_xformers = xformers
|
||||||
self.use_memory_efficient_attention_mem_eff = mem_eff
|
self.use_memory_efficient_attention_mem_eff = mem_eff
|
||||||
@@ -607,7 +610,28 @@ class CrossAttention(nn.Module):
|
|||||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def forward(self, hidden_states, context=None, mask=None):
|
def set_processor(self, processor=None):
    """Install an attention processor on this module and return the active one.

    The original implementation took no argument and only returned
    ``self.processor``, so nothing could actually be installed through it.

    Args:
        processor: Callable to store in ``self.processor``. ``forward`` hands
            the attention computation to it whenever it is not ``None``,
            calling it with ``attn``, ``hidden_states``,
            ``encoder_hidden_states`` and ``attention_mask`` keyword
            arguments. When omitted (or ``None``) the current processor is
            left untouched, which keeps the old zero-argument call working.
            To clear a processor, assign ``self.processor = None`` directly.

    Returns:
        The processor stored on the module after the call (``None`` if no
        processor has been installed).
    """
    if processor is not None:
        self.processor = processor
    return self.processor
|
||||||
|
|
||||||
|
def get_processor(self):
    """Return the attention processor currently attached to this module.

    The value is ``None`` when no processor has been installed.
    """
    current = self.processor
    return current
|
||||||
|
|
||||||
|
def forward(self, hidden_states, context=None, mask=None, **kwargs):
|
||||||
|
if self.processor is not None:
|
||||||
|
(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
) = translate_attention_names_from_diffusers(
|
||||||
|
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
||||||
|
)
|
||||||
|
return self.processor(
|
||||||
|
attn=self,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
encoder_hidden_states=context,
|
||||||
|
attention_mask=mask,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
if self.use_memory_efficient_attention_xformers:
|
if self.use_memory_efficient_attention_xformers:
|
||||||
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
||||||
if self.use_memory_efficient_attention_mem_eff:
|
if self.use_memory_efficient_attention_mem_eff:
|
||||||
@@ -720,6 +744,21 @@ class CrossAttention(nn.Module):
|
|||||||
out = self.to_out[0](out)
|
out = self.to_out[0](out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def translate_attention_names_from_diffusers(
    hidden_states: torch.FloatTensor,
    context: Optional[torch.FloatTensor] = None,
    mask: Optional[torch.FloatTensor] = None,
    # Hugging Face diffusers naming for the same two arguments
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
):
    """Map diffusers-style attention argument names onto this file's names.

    ``encoder_hidden_states`` is the diffusers spelling of ``context`` and
    ``attention_mask`` is the diffusers spelling of ``mask``. For each pair,
    the native argument wins when it is not ``None``; otherwise the
    diffusers-named value is used.

    Returns:
        A ``(hidden_states, context, mask)`` tuple; ``hidden_states`` is
        passed through unchanged.
    """
    # Fall back to the diffusers-named value only when the native one is absent.
    if context is None:
        context = encoder_hidden_states

    if mask is None:
        mask = attention_mask

    return hidden_states, context, mask
|
||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
@@ -1350,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
self.out_channels = OUT_CHANNELS
|
self.out_channels = OUT_CHANNELS
|
||||||
|
|
||||||
self.sample_size = sample_size
|
self.sample_size = sample_size
|
||||||
self.prepare_config()
|
self.prepare_config(sample_size=sample_size)
|
||||||
|
|
||||||
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
||||||
|
|
||||||
@@ -1437,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
||||||
|
|
||||||
# region diffusers compatibility
|
# region diffusers compatibility
|
||||||
def prepare_config(self):
|
def prepare_config(self, *args, **kwargs):
|
||||||
self.config = SimpleNamespace()
|
self.config = SimpleNamespace(**kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
|
|||||||
Reference in New Issue
Block a user