diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 27bd7460..4d8121ca 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -78,7 +78,7 @@ from diffusers import ( HeunDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, - UNet2DConditionModel, + # UNet2DConditionModel, StableDiffusionPipeline, ) from einops import rearrange @@ -95,6 +95,7 @@ import library.train_util as train_util from networks.lora import LoRANetwork import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo +from library.original_unet import UNet2DConditionModel from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI diff --git a/library/model_util.py b/library/model_util.py index 26f72235..0fbc6590 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,8 +5,9 @@ import math import os import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline #, UNet2DConditionModel from safetensors.torch import load_file, save_file +from library.original_unet import UNet2DConditionModel # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 diff --git a/library/original_unet.py b/library/original_unet.py new file mode 100644 index 00000000..603239d5 --- /dev/null +++ b/library/original_unet.py @@ -0,0 +1,1234 @@ +# Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる +# コードの多くはDiffusersからコピーしている +# コードが冗長になる部分はコメント等を適宜削除する +# 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある + +# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers. +# Remove redundant code by deleting comments, etc. as appropriate +# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2 + +""" +v1.5とv2.1の相違点は +- attention_head_dimがintかlist[int]か +- cross_attention_dimが768か1024か +- use_linear_projection: trueがない(=False, 1.5)かあるか +- upcast_attentionがFalse(1.5)かTrue(2.1)か +- (以下は多分無視していい) +- sample_sizeが64か96か +- dual_cross_attentionがあるかないか +- num_class_embedsがあるかないか +- only_cross_attentionがあるかないか + +v1.5 +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.6.0", + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 768, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "out_channels": 4, + "sample_size": 64, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ] +} + +v2.1 +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.10.0.dev0", + "act_fn": "silu", + "attention_head_dim": [ + 5, + 10, + 20, + 20 + ], + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 1024, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 4, + "sample_size": 96, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "use_linear_projection": true, + "upcast_attention": true +} +""" + +import math +from typing import Dict, Optional, Tuple, Union +import torch +from torch import nn +from torch.nn import functional as F + +BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) +TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] +TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4 +IN_CHANNELS: int = 4 +OUT_CHANNELS: int = 4 +LAYERS_PER_BLOCK: int = 2 +LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1 +TIME_EMBED_FLIP_SIN_TO_COS: bool = True +TIME_EMBED_FREQ_SHIFT: int = 0 +RESNET_GROUPS: int = 32 +RESNET_EPS: float = 1e-6 +TRANSFORMER_NORM_NUM_GROUPS = 32 + +DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] +UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] + + +def get_parameter_dtype(parameter: torch.nn.Module): + return next(parameter.parameters()).dtype + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class SampleOutput: + def __init__(self, sample): + self.sample = sample + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.norm1 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=in_channels, eps=RESNET_EPS, affine=True) + + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels) + + self.norm2 = torch.nn.GroupNorm(num_groups=RESNET_GROUPS, num_channels=out_channels, eps=RESNET_EPS, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + # if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + + self.use_in_shortcut = self.in_channels != self.out_channels + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + add_downsample=True, + ): + super().__init__() + + self.has_cross_attention = False + resnets = [] + + for i in range(LAYERS_PER_BLOCK): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)] + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class Downsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + + self.channels = channels + self.out_channels = out_channels + + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + # no dropout here + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + hidden_states = self._attention(query, key, value) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # hidden_states = self.to_out[1](hidden_states) # no dropout + return hidden_states + + def _attention(self, query, key, value): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + ): + super().__init__() + inner_dim = int(dim * 4) # mult is always 4 + + self.net = nn.ModuleList([]) + # project in + self.net.append(GEGLU(dim, inner_dim)) + # project dropout + self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 + # project out + self.net.append(nn.Linear(inner_dim, dim)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False + ): + super().__init__() + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward(dim) + + # 2. Cross-Attn + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + raise NotImplementedError("Memory efficient attention is not implemented for this model.") + + def forward(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class Transformer2DModel(nn.Module): + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_linear_projection = use_linear_projection + + self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True) + + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + ] + ) + + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # 1. Input + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. Output + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return SampleOutput(sample=output) + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + add_downsample=True, + cross_attention_dim=1280, + attn_num_head_channels=1, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + resnets = [] + attentions = [] + + self.attn_num_head_channels = attn_num_head_channels + + for i in range(LAYERS_PER_BLOCK): + in_channels = in_channels if i == 0 else out_channels + + resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels)) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + attn_num_head_channels=1, + cross_attention_dim=1280, + use_linear_projection=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + # Middle block has two resnets and one attention + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + ), + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + ), + ] + attentions = [ + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class Upsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + add_upsample=True, + ): + super().__init__() + + self.has_cross_attention = False + resnets = [] + + for i in range(LAYERS_PER_BLOCK_UP): + res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + attn_num_head_channels=1, + cross_attention_dim=1280, + add_upsample=True, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(LAYERS_PER_BLOCK_UP): + res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + ) + ) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +def get_down_block( + down_block_type, + in_channels, + out_channels, + add_downsample, + attn_num_head_channels, + cross_attention_dim, + use_linear_projection, + upcast_attention, +): + if down_block_type == "DownBlock2D": + return DownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlock2D": + return CrossAttnDownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + +def get_up_block( + up_block_type, + in_channels, + out_channels, + prev_output_channel, + add_upsample, + attn_num_head_channels, + cross_attention_dim=None, + use_linear_projection=False, + upcast_attention=False, +): + if up_block_type == "UpBlock2D": + return UpBlock2D( + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlock2D": + return CrossAttnUpBlock2D( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + attn_num_head_channels=attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + add_upsample=add_upsample, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + +class UNet2DConditionModel(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + sample_size: Optional[int] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + cross_attention_dim: int = 1280, + use_linear_projection: bool = False, + upcast_attention: bool = False, + **kwargs, + ): + super().__init__() + assert sample_size is not None, "sample_size must be specified" + print( + f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}" + ) + + # 外部からの参照用に定義しておく + self.in_channels = IN_CHANNELS + self.out_channels = OUT_CHANNELS + + self.sample_size = sample_size + + # input + self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT) + + self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * 4 + + # down + output_channel = BLOCK_OUT_CHANNELS[0] + for i, down_block_type in enumerate(DOWN_BLOCK_TYPES): + input_channel = output_channel + output_channel = BLOCK_OUT_CHANNELS[i] + is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 + + down_block = get_down_block( + down_block_type, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + attn_num_head_channels=attention_head_dim[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=BLOCK_OUT_CHANNELS[-1], + attn_num_head_channels=attention_head_dim[-1], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(UP_BLOCK_TYPES): + is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + add_upsample=add_upsample, + attn_num_head_channels=reversed_attention_head_dim[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=RESNET_GROUPS, eps=RESNET_EPS) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) + + # region diffusers compatibility + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def set_attention_slice(self, slice_size): + raise NotImplementedError("Attention slicing is not supported for this model.") + + def is_gradient_checkpointing(self) -> bool: + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + self._set_gradient_checkpointing(self, value=True) + + def disable_gradient_checkpointing(self): + self._set_gradient_checkpointing(self, value=False) + + def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None: + raise NotImplementedError("Memory efficient attention is not supported for this model.") + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + + # endregion + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[Dict, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) + + def handle_unusual_timesteps(self, sample, timesteps): + r""" + timestampsがTensorでない場合、Tensorに変換する。またOnnx/Core MLと互換性のあるようにbatchサイズまでbroadcastする。 + """ + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + return timesteps