diff --git a/library/original_unet.py b/library/original_unet.py
index 938b0b64..00997e7c 100644
--- a/library/original_unet.py
+++ b/library/original_unet.py
@@ -586,6 +586,9 @@ class CrossAttention(nn.Module):
         self.use_memory_efficient_attention_mem_eff = False
         self.use_sdpa = False

+        # Attention processor
+        self.processor = None
+
     def set_use_memory_efficient_attention(self, xformers, mem_eff):
         self.use_memory_efficient_attention_xformers = xformers
         self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -607,7 +610,28 @@ class CrossAttention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

-    def forward(self, hidden_states, context=None, mask=None):
+    def set_processor(self, processor):
+        self.processor = processor
+
+    def get_processor(self):
+        return self.processor
+
+    def forward(self, hidden_states, context=None, mask=None, **kwargs):
+        if self.processor is not None:
+            (
+                hidden_states,
+                encoder_hidden_states,
+                attention_mask,
+            ) = translate_attention_names_from_diffusers(
+                hidden_states=hidden_states, context=context, mask=mask, **kwargs
+            )
+            # dispatch to the diffusers-style processor with the translated names
+            return self.processor(
+                attn=self,
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+            )
         if self.use_memory_efficient_attention_xformers:
             return self.forward_memory_efficient_xformers(hidden_states, context, mask)
         if self.use_memory_efficient_attention_mem_eff:
@@ -720,6 +744,21 @@ class CrossAttention(nn.Module):
         out = self.to_out[0](out)
         return out

+def translate_attention_names_from_diffusers(
+    hidden_states: torch.FloatTensor,
+    context: Optional[torch.FloatTensor] = None,
+    mask: Optional[torch.FloatTensor] = None,
+    # HF naming
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+):
+    # translate from Hugging Face diffusers naming: encoder_hidden_states -> context
+    context = context if context is not None else encoder_hidden_states
+
+    # translate from Hugging Face diffusers naming: attention_mask -> mask
+    mask = mask if mask is not None else attention_mask
+
+    return hidden_states, context, mask

 # feedforward
 class GEGLU(nn.Module):
@@ -1350,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):

         self.out_channels = OUT_CHANNELS
         self.sample_size = sample_size
-        self.prepare_config()
+        self.prepare_config(sample_size=sample_size)

         # the state_dict format would change, so the way modules are held cannot be changed

@@ -1437,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
         self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)

     # region diffusers compatibility
-    def prepare_config(self):
-        self.config = SimpleNamespace()
+    def prepare_config(self, *args, **kwargs):
+        self.config = SimpleNamespace(**kwargs)

     @property
     def dtype(self) -> torch.dtype:
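For context, here is a minimal sketch of a processor that plugs into the `set_processor` hook this patch adds. It follows the `processor(attn=..., hidden_states=..., encoder_hidden_states=..., attention_mask=...)` call convention that `forward` dispatches to above; `attn.to_q`, `to_k`, `to_v`, and `to_out[0]` are the projection names used in `original_unet.py`, while `attn.heads`, a broadcastable `attention_mask`, and the `SDPAProcessor` name itself are assumptions for illustration, not part of this patch.

```python
# Minimal sketch of a diffusers-style attention processor for the hook above.
# Assumptions: `attn.heads` holds the head count and `attention_mask` (if any)
# broadcasts against the attention scores; `SDPAProcessor` is a hypothetical
# name, not something defined in this patch.
import torch.nn.functional as F


class SDPAProcessor:
    # Matches the call made by CrossAttention.forward:
    #   processor(attn=self, hidden_states=..., encoder_hidden_states=..., attention_mask=...)
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # self-attention if no context is supplied, cross-attention otherwise
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        batch, q_len, inner_dim = query.shape
        head_dim = inner_dim // attn.heads

        # (batch, seq, inner) -> (batch, heads, seq, head_dim)
        query = query.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch, -1, attn.heads, head_dim).transpose(1, 2)

        # SDPA applies the usual 1/sqrt(head_dim) scaling internally
        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

        # back to (batch, seq, inner) and through the output projection
        out = out.transpose(1, 2).reshape(batch, q_len, inner_dim)
        return attn.to_out[0](out)


def apply_processor(unet):
    # route every CrossAttention module in the given UNet through the processor
    for module in unet.modules():
        if module.__class__.__name__ == "CrossAttention":
            module.set_processor(SDPAProcessor())
```

Because `forward` returns the processor's output directly when one is set, a processor installed this way fully replaces the xformers/mem-eff/SDPA branches for that module.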