Files
Kohya-ss-sd-scripts/library/stable_cascade.py
2024-02-17 23:59:20 +09:00

1374 lines
56 KiB
Python

# コードは Stable Cascade からコピーし、一部修正しています。元ライセンスは MIT です。
# The code is copied from Stable Cascade and modified. The original license is MIT.
# https://github.com/Stability-AI/StableCascade
import math
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import torchvision
# region VectorQuantize
# from torchtools https://github.com/pabloppp/pytorch-tools
# 依存ライブラリを増やしたくないのでここにコピペ
class vector_quantize(torch.autograd.Function):
    """Nearest-codebook lookup with a straight-through gradient.

    ``forward(x, codebook)`` picks, for every row of ``x``, the codebook row with the
    smallest squared Euclidean distance and returns ``(selected_rows, indices)``.
    ``backward`` copies the incoming gradient to ``x`` unchanged and scatter-adds it
    into the codebook rows that were selected.
    """

    @staticmethod
    def forward(ctx, x, codebook):
        with torch.no_grad():
            # ||x - c||^2 = ||c||^2 + ||x||^2 - 2 * x.c for every (row, code) pair, in one addmm.
            squared_code_norms = torch.sum(codebook**2, dim=1)
            squared_row_norms = torch.sum(x**2, dim=1, keepdim=True)
            distances = torch.addmm(squared_code_norms + squared_row_norms, x, codebook.t(), alpha=-2.0, beta=1.0)
            indices = distances.min(dim=1).indices
            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)
            quantized = torch.index_select(codebook, 0, indices)
        return quantized, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        wants_x_grad, wants_codebook_grad = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        grad_x = grad_output.clone() if wants_x_grad else None
        grad_codebook = None
        if wants_codebook_grad:
            # Every selected code receives the sum of the gradients of the rows mapped to it.
            indices, codebook = ctx.saved_tensors
            grad_codebook = torch.zeros_like(codebook).index_add_(0, indices, grad_output)
        return (grad_x, grad_codebook)
class VectorQuantize(nn.Module):
    """Vector-quantization layer whose codebook is an ``nn.Embedding`` of ``k`` vectors."""

    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest-neighbour embeddings to each of the inputs,
        with the same size as the input, the vq and commitment components of the loss as a tuple
        in the second output, and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices

        When ``ema_loss`` is True, the codebook is additionally updated in training mode
        from exponential moving averages (decay ``ema_decay``) of assignment counts and sums.
        """
        super().__init__()
        self.codebook = nn.Embedding(k, embedding_size)
        init_bound = 1.0 / k
        self.codebook.weight.data.uniform_(-init_bound, init_bound)
        self.vq = vector_quantize.apply
        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            self.register_buffer("ema_element_count", torch.ones(k))
            self.register_buffer("ema_weight_sum", torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        """Additively smooth the counts ``x`` by ``epsilon`` while preserving their total."""
        total = torch.sum(x)
        return (x + epsilon) / (total + x.size(0) * epsilon) * total

    def _updateEMA(self, z_e_x, indices):
        """Fold the current batch assignments into the EMA buffers and rebuild the codebook."""
        code_count = self.ema_element_count.size(0)
        assignment = nn.functional.one_hot(indices, code_count).float()
        batch_counts = assignment.sum(dim=0)
        batch_sums = torch.mm(assignment.t(), z_e_x)
        decay = self.ema_decay
        self.ema_element_count = (decay * self.ema_element_count) + ((1 - decay) * batch_counts)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (decay * self.ema_weight_sum) + ((1 - decay) * batch_sums)
        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        """Look up codebook vectors for ``idx``; the embedding axis is moved to ``dim``."""
        vectors = self.codebook(idx)
        return vectors if dim == -1 else vectors.movedim(-1, dim)

    def forward(self, x, get_losses=True, dim=-1):
        """Quantize ``x`` along axis ``dim``.

        Returns ``(quantized, (vq_loss, commit_loss), indices)``. ``quantized`` has the shape
        of ``x``, ``indices`` has the shape of ``x`` without the ``dim`` axis, and both losses
        are None when ``get_losses`` is False.
        """
        if dim != -1:
            x = x.movedim(dim, -1)
        flat = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        z_q_x, indices = self.vq(flat, self.codebook.weight.detach())

        vq_loss = commit_loss = None
        if self.ema_loss and self.training:
            self._updateEMA(flat.detach(), indices.detach())
        # Gather the gradient-carrying codebook rows only after the optional EMA refresh,
        # so the losses are measured against the updated codebook.
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - flat.detach()).pow(2).mean()
            commit_loss = (flat - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if dim != -1:
            z_q_x = z_q_x.movedim(-1, dim)
        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
# endregion
class EfficientNetEncoder(nn.Module):
    """Map images to a ``c_latent``-channel feature map using torchvision's EfficientNetV2-S.

    The backbone's ``features`` trunk (1280 output channels) is followed by a bias-free
    1x1 projection and a non-affine BatchNorm that standardizes each latent channel.
    """

    def __init__(self, c_latent=16):
        super().__init__()
        # weights="DEFAULT" asks torchvision for its default pretrained weights
        # (downloaded on first use if not cached locally).
        full_model = torchvision.models.efficientnet_v2_s(weights="DEFAULT")
        # NOTE: .eval() is applied once here; a later .train() on a parent module flips it back.
        self.backbone = full_model.features.eval()
        projection = nn.Conv2d(1280, c_latent, kernel_size=1, bias=False)
        standardize = nn.BatchNorm2d(c_latent, affine=False)  # mean 0 / std 1 per channel
        self.mapper = nn.Sequential(projection, standardize)

    def forward(self, x):
        features = self.backbone(x)
        return self.mapper(features)
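# なんかわりと乱暴な実装(;'∀')
# (A rather crude implementation: these subclasses skip PyTorch's default parameter init.)
# 一から学習することもないだろうから、無効化しておく
# (Training from scratch is unlikely, so the overrides are disabled and the stock
#  torch.nn.Conv2d / torch.nn.Linear are imported below instead.)
# class Linear(torch.nn.Linear):
#     def reset_parameters(self):
#         return None
# class Conv2d(torch.nn.Conv2d):
#     def reset_parameters(self):
#         return None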
from torch.nn import Conv2d
from torch.nn import Linear
class Attention2D(nn.Module):
    """Multi-head attention from every spatial position of a (B, C, H, W) map to a token sequence."""

    def __init__(self, c, nhead, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)

    def forward(self, x, kv, self_attn=False):
        """Attend from the pixels of ``x`` to ``kv``.

        Args:
            x: feature map; its H*W positions become query tokens of width C.
            kv: key/value tokens, assumed (B, L, C) — TODO confirm with callers.
            self_attn: when True the pixel tokens are prepended to ``kv``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        original_shape = x.shape
        batch, channels = x.size(0), x.size(1)
        tokens = x.view(batch, channels, -1).permute(0, 2, 1)  # B x C x H x W -> B x (H*W) x C
        if self_attn:
            kv = torch.cat([tokens, kv], dim=1)
        attended, _ = self.attn(tokens, kv, kv, need_weights=False)
        return attended.permute(0, 2, 1).view(*original_shape)
class LayerNorm2d(nn.LayerNorm):
    """``nn.LayerNorm`` over the channel axis of a channels-first (B, C, H, W) tensor."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)  # B, C, H, W -> B, H, W, C
        normalized = super().forward(channels_last)
        return normalized.permute(0, 3, 1, 2)  # back to B, C, H, W
class GlobalResponseNorm(nn.Module):
    """Global Response Normalization (ConvNeXt-V2) for channels-last (B, H, W, C) tensors.

    Adapted from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    Both ``gamma`` and ``beta`` start at zero, so a freshly built layer is the identity.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # L2 norm of each channel over the spatial axes (1, 2), relative to the mean over channels.
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        relative_norm = spatial_norm / (spatial_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x
class ResBlock(nn.Module):
    """Residual block: depthwise conv -> LayerNorm -> channel MLP (GELU + GRN), added to the input.

    If ``c_skip`` > 0, ``forward`` expects an ``x_skip`` tensor with ``c_skip`` channels that
    is concatenated to the normalized features before the MLP.
    """

    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        super().__init__()
        self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        hidden = c * 4
        self.channelwise = nn.Sequential(
            Linear(c + c_skip, hidden),
            nn.GELU(),
            GlobalResponseNorm(hidden),
            nn.Dropout(dropout),
            Linear(hidden, c),
        )

    def forward(self, x, x_skip=None):
        residual = x
        h = self.norm(self.depthwise(x))
        if x_skip is not None:
            h = torch.cat([h, x_skip], dim=1)  # concatenate along channels
        h = h.permute(0, 2, 3, 1)  # channels-last for the Linear layers
        h = self.channelwise(h).permute(0, 3, 1, 2)
        return h + residual
class AttnBlock(nn.Module):
    """Pre-norm residual cross-attention block.

    The conditioning tokens ``kv`` (width ``c_cond``) go through SiLU + Linear to width ``c``.
    When ``self_attn`` is True, the attention also includes the pixels of ``x`` in its keys/values.
    """

    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.attention = Attention2D(c, nhead, dropout)
        self.kv_mapper = nn.Sequential(nn.SiLU(), Linear(c_cond, c))

    def forward(self, x, kv):
        projected_kv = self.kv_mapper(kv)
        attended = self.attention(self.norm(x), projected_kv, self_attn=self.self_attn)
        return x + attended
class FeedForwardBlock(nn.Module):
    """Pre-norm residual channel MLP (4x expansion, GELU, GRN, dropout) for (B, C, H, W) maps."""

    def __init__(self, c, dropout=0.0):
        super().__init__()
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        hidden = c * 4
        self.channelwise = nn.Sequential(
            Linear(c, hidden), nn.GELU(), GlobalResponseNorm(hidden), nn.Dropout(dropout), Linear(hidden, c)
        )

    def forward(self, x):
        h = self.norm(x).permute(0, 2, 3, 1)  # channels-last for the Linear layers
        h = self.channelwise(h).permute(0, 3, 1, 2)
        return x + h
class TimestepBlock(nn.Module):
    """Scale-and-shift modulation of a feature map by one or more embedded scalar conditions.

    ``t`` is split into ``len(conds) + 1`` equal chunks along dim 1. The first chunk feeds
    ``self.mapper``, and chunk ``i + 1`` feeds ``self.mapper_<conds[i]>``. Each mapper yields
    a scale ``a`` and a shift ``b``. The contributions are summed and applied as
    ``x * (1 + a) + b``.
    """

    def __init__(self, c, c_timestep, conds=None):
        """
        Args:
            c: channel count of the modulated feature map.
            c_timestep: width of each chunk of ``t``.
            conds: names of the extra conditions, ``["sca"]`` when omitted. The sequence is
                copied, so neither the caller's list nor a default is shared between instances
                (the old ``conds=["sca"]`` default was a single list shared by every block).
        """
        super().__init__()
        self.mapper = Linear(c_timestep, c * 2)
        self.conds = ["sca"] if conds is None else list(conds)
        for cname in self.conds:
            setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))

    def forward(self, x, t):
        t = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        for i, c in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
class UpDownBlock2d(nn.Module):
    """1x1 channel mapping combined with an optional 2x bilinear resampling.

    ``mode="up"`` resamples first and then maps channels. ``mode="down"`` maps channels first
    and then resamples. With ``enabled=False`` the resampling stage is an ``nn.Identity``, so
    only the channel count changes.
    """

    def __init__(self, c_in, c_out, mode, enabled=True):
        super().__init__()
        assert mode in ["up", "down"]
        going_up = mode == "up"
        if enabled:
            interpolation = nn.Upsample(scale_factor=2 if going_up else 0.5, mode="bilinear", align_corners=True)
        else:
            interpolation = nn.Identity()
        mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
        stages = [interpolation, mapping] if going_up else [mapping, interpolation]
        self.blocks = nn.ModuleList(stages)

    def forward(self, x):
        for stage in self.blocks:
            # Each stage is fed a float32 copy of its input, whatever the incoming dtype.
            x = stage(x.float())
        return x
class StageAResBlock(nn.Module):
    """Residual block used by Stage A: a depthwise-conv branch and a channel-MLP branch.

    Each branch LayerNorms its input, applies a learned scale/shift, and is added back scaled
    by a learned gate. The six scalars live in ``gammas`` and start at zero, so a new block is
    the identity. Linear/Conv2d weights get Xavier-uniform init and zero biases.
    """

    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise branch
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(nn.ReplicationPad2d(1), nn.Conv2d(c, c, kernel_size=3, groups=c))

        # channelwise branch
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(nn.Linear(c, c_hidden), nn.GELU(), nn.Linear(c_hidden, c))

        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        self.apply(self._xavier_init)

    @staticmethod
    def _xavier_init(module):
        """Xavier-uniform weights and zero bias for every Linear / Conv2d submodule."""
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    def _norm(self, x, norm):
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        g = self.gammas
        h = self._norm(x, self.norm1) * (1 + g[0]) + g[1]
        x = x + self.depthwise(h) * g[2]
        h = self._norm(x, self.norm2) * (1 + g[3]) + g[4]
        mlp_out = self.channelwise(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x + mlp_out * g[5]
class StageA(nn.Module):
    """Stage A autoencoder with an optional vector-quantized bottleneck.

    Encoder: PixelUnshuffle(2) -> 1x1 conv -> ``levels`` StageAResBlocks separated by strided convs
    -> 1x1 projection to ``c_latent`` channels + BatchNorm. Decoder mirrors it with
    ``bottleneck_blocks`` blocks at the lowest level and ConvTranspose2d upsampling.
    Latents leaving ``encode`` are divided by ``scale_factor`` and ``decode`` multiplies it back.
    """

    def __init__(
        self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.43
    ):  # 0.3764
        super().__init__()
        self.c_latent = c_latent
        self.scale_factor = scale_factor
        # Channel widths from the highest-resolution level to the lowest, e.g. [192, 384].
        c_levels = [c_hidden // (2**i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(nn.PixelUnshuffle(2), nn.Conv2d(3 * 4, c_levels[0], kernel_size=1))
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = StageAResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(
            nn.Sequential(
                nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
                nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
            )
        )
        self.down_blocks = nn.Sequential(*down_blocks)
        # (A no-op expression statement `self.down_blocks[0]` used to sit here; it was removed.)
        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))]
        for i in range(levels):
            for _ in range(bottleneck_blocks if i == 0 else 1):
                block = StageAResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(
                        c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1
                    )
                )
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Encode images ``x`` (3 channels) into scaled latents.

        Returns a 4-tuple:
            quantize=True:  (quantized / scale_factor, unquantized / scale_factor, codebook indices,
                             vq_loss + 0.25 * commit_loss)
            quantize=False: (latents / scale_factor, None, None, None)
        """
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
        else:
            return x / self.scale_factor, None, None, None

    def decode(self, x):
        """Decode scaled latents ``x`` (as produced by ``encode``) back to 3-channel images."""
        x = x * self.scale_factor
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        """Round-trip ``x`` through encode/decode; returns ``(reconstruction, vq_loss_or_None)``."""
        qe, _, _, vq_loss = self.encode(x, quantize)
        return self.decode(qe), vq_loss
r"""
https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_b_3b.yaml
# GLOBAL STUFF
model_version: 3B
dtype: bfloat16
# For demonstration purposes in reconstruct_images.ipynb
webdataset_path: file:inference/imagenet_1024.tar
batch_size: 4
image_size: 1024
grad_accum_steps: 1
effnet_checkpoint_path: models/effnet_encoder.safetensors
stage_a_checkpoint_path: models/stage_a.safetensors
generator_checkpoint_path: models/stage_b_bf16.safetensors
"""
class StageB(nn.Module):
    """Stable Cascade Stage B generator: an encoder/decoder (UNet-style) with skip connections.

    ``forward`` takes a ``c_in``-channel input ``x``, a per-sample scalar ``r``, EfficientNet
    features ``effnet`` (``c_effnet`` channels), a CLIP embedding ``clip`` (``c_clip`` wide) and
    an optional small image ``pixels`` (``c_pixels`` channels). Level ``i`` has ``c_hidden[i]``
    channels and stacks the block letters of ``level_config[i]``: "C" ResBlock, "A" AttnBlock,
    "F" FeedForwardBlock, "T" TimestepBlock.
    """

    def __init__(
        self,
        c_in=4,
        c_out=4,
        c_r=64,
        patch_size=2,
        c_cond=1280,
        c_hidden=[320, 640, 1280, 1280],
        nhead=[-1, -1, 20, 20],
        blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
        block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]],
        level_config=["CT", "CT", "CTA", "CTA"],
        c_clip=1280,
        c_clip_seq=4,
        c_effnet=16,
        c_pixels=3,
        kernel_size=3,
        dropout=[0, 0, 0.1, 0.1],
        self_attn=True,
        t_conds=["sca"],
    ):
        """Build the network.

        Args (non-obvious ones):
            c_r: width of each sinusoidal embedding made by ``gen_r_embedding``.
            c_cond: width of the conditioning tokens fed to the attention blocks.
            blocks: ``[down_counts, up_counts]`` — how often ``level_config[i]`` is stacked per level
                (``up_counts`` is listed from the lowest-resolution level upwards).
            block_repeat: ``[down_repeats, up_repeats]`` — each level's stack is run that many times
                with a 1x1 conv "repeat mapper" between passes.
                NOTE(review): with ``None`` no mapper lists are built, and the ``zip`` in
                ``_down_encode``/``_up_decode`` would then skip every level.
            c_clip_seq: number of conditioning tokens produced from each CLIP vector.
            dropout, self_attn: a scalar is broadcast to one value per level.
            t_conds: names of extra scalar conditions, read from ``forward``'s ``**kwargs``.

        NOTE: the list defaults are shared between calls. They are only read here, but
        ``t_conds`` is stored on the instance and passed to every TimestepBlock — treat it as
        read-only.
        """
        super().__init__()
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.effnet_mapper = nn.Sequential(
            nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
        )
        self.pixels_mapper = nn.Sequential(
            nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
        )
        self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq)
        self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Factory for one block of a level, keyed by its letter in level_config.
            if block_type == "C":
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
            elif block_type == "A":
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
            elif block_type == "F":
                return FeedForwardBlock(c_hidden, dropout=dropout)
            elif block_type == "T":
                return TimestepBlock(c_hidden, c_r, conds=t_conds)
            else:
                raise Exception(f"Block type {block_type} not supported")

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(
                    nn.Sequential(
                        LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                        nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
                    )
                )
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(
                    nn.Sequential(
                        LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
                        nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
                    )
                )
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # Only the first block of every level except the lowest-resolution one
                    # receives the encoder skip connection.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(
                        block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i]
                    )
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)  # General init
        nn.init.normal_(self.clip_mapper.weight, std=0.02)  # conditionings
        nn.init.normal_(self.effnet_mapper[0].weight, std=0.02)  # conditionings
        nn.init.normal_(self.effnet_mapper[2].weight, std=0.02)  # conditionings
        nn.init.normal_(self.pixels_mapper[0].weight, std=0.02)  # conditionings
        nn.init.normal_(self.pixels_mapper[2].weight, std=0.02)  # conditionings
        torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        nn.init.constant_(self.clf[1].weight, 0)  # outputs

        # blocks: shrink the last MLP layer of residual blocks, zero the TimestepBlock mappers
        for level_block in self.down_blocks + self.up_blocks:
            for block in level_block:
                if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
                elif isinstance(block, TimestepBlock):
                    for layer in block.modules():
                        if isinstance(layer, nn.Linear):
                            nn.init.constant_(layer.weight, 0)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero bias for Conv2d / Linear modules."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _is_block(block, block_cls):
        """True if ``block`` is a ``block_cls``, either directly or as FSDP's ``_fsdp_wrapped_module``."""
        return isinstance(block, block_cls) or isinstance(getattr(block, "_fsdp_wrapped_module", None), block_cls)

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of ``r * max_positions`` with ``self.c_r`` features.

        ``r`` is indexed as ``r[:, None]``, so a 1-D (B,) tensor is expected. Returns (B, c_r);
        the last column is zero-padded when ``c_r`` is odd.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        return emb

    def gen_c_embeddings(self, clip):
        """Turn CLIP vectors into ``c_clip_seq`` normalized tokens each.

        A 2-D ``clip`` is treated as one vector per sample. Output: (B, N * c_clip_seq, c_cond).
        """
        if len(clip.shape) == 2:
            clip = clip.unsqueeze(1)
        clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip):
        """Run the down path; returns per-level outputs ordered from lowest resolution to highest."""
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if self._is_block(block, ResBlock):
                        x = block(x)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip):
        """Run the up path, feeding the matching encoder output into the first ResBlock of each level."""
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if self._is_block(block, ResBlock):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            # Resize to the skip's spatial size when the two disagree.
                            x = torch.nn.functional.interpolate(
                                x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                            )
                        x = block(x, skip)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
        """Predict the ``c_out``-channel output for input ``x``.

        Args:
            x: (B, c_in, H, W) input.
            r: (B,) scalar condition, embedded with ``gen_r_embedding``.
            effnet: EfficientNet feature map, bilinearly resized to the embedded resolution.
            clip: CLIP embedding(s), see ``gen_c_embeddings``.
            pixels: optional small image; defaults to an 8x8 all-zero image.
            **kwargs: values for the names in ``self.t_conds``; missing ones default to zeros.
        """
        if pixels is None:
            # Use the channel count the pixels mapper was built with. This used to be
            # hard-coded to 3, which failed for any model constructed with c_pixels != 3.
            pixels = x.new_zeros(x.size(0), self.pixels_mapper[0].in_channels, 8, 8)

        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
        clip = self.gen_c_embeddings(clip)

        # Model Blocks
        x = self.embedding(x)
        x = x + self.effnet_mapper(
            nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode="bilinear", align_corners=True)
        )
        x = x + nn.functional.interpolate(
            self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode="bilinear", align_corners=True
        )
        level_outputs = self._down_encode(x, r_embed, clip)
        x = self._up_decode(level_outputs, r_embed, clip)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """In place: ``self = beta * self + (1 - beta) * src_model`` for parameters and buffers.

        Pairs are matched positionally via ``zip``, so ``src_model`` is assumed to have the
        same structure.
        """
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
r"""
https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_c_3b.yaml
# GLOBAL STUFF
model_version: 3.6B
dtype: bfloat16
effnet_checkpoint_path: models/effnet_encoder.safetensors
previewer_checkpoint_path: models/previewer.safetensors
generator_checkpoint_path: models/stage_c_bf16.safetensors
"""
class StageC(nn.Module):
    """Stable Cascade Stage C generator: an encoder/decoder (UNet-style) over ``c_in``-channel inputs.

    It is conditioned on a per-sample scalar ``r`` (plus the extra scalars named in ``t_conds``)
    and on CLIP text, pooled-text and image embeddings. ``gen_c_embeddings`` turns these into a
    single token sequence for the attention blocks. Level ``i`` has ``c_hidden[i]`` channels and
    stacks the block letters of ``level_config[i]``: "C" ResBlock, "A" AttnBlock,
    "F" FeedForwardBlock, "T" TimestepBlock.
    """

    def __init__(
        self,
        c_in=16,
        c_out=16,
        c_r=64,
        patch_size=1,
        c_cond=2048,
        c_hidden=[2048, 2048],
        nhead=[32, 32],
        blocks=[[8, 24], [24, 8]],
        block_repeat=[[1, 1], [1, 1]],
        level_config=["CTA", "CTA"],
        c_clip_text=1280,
        c_clip_text_pooled=1280,
        c_clip_img=768,
        c_clip_seq=4,
        kernel_size=3,
        dropout=[0.1, 0.1],
        self_attn=True,
        t_conds=["sca", "crp"],
        switch_level=[False],
    ):
        """Build the network.

        Args (non-obvious ones):
            c_cond: width of the conditioning tokens fed to the attention blocks.
            blocks / block_repeat: per-level stack counts and pass counts, ``[down, up]``
                (``up`` lists are indexed from the lowest-resolution level upwards).
                NOTE(review): ``block_repeat=None`` builds no mapper lists, and the ``zip`` in
                ``_down_encode``/``_up_decode`` would then skip every level.
            c_clip_seq: number of tokens produced from each pooled-text / image vector.
            switch_level: ``switch_level[i - 1]`` enables the 2x bilinear resampling between
                levels ``i - 1`` and ``i``. When False, only the channel count changes.
            t_conds: names of extra scalar conditions, read from ``forward``'s ``**kwargs``.

        NOTE: the list defaults are shared between calls. They are only read here, but
        ``t_conds`` is stored on the instance and passed to every TimestepBlock — treat it as
        read-only.
        """
        super().__init__()
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond)
        self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq)
        self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq)
        self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Factory for one block of a level, keyed by its letter in level_config.
            if block_type == "C":
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
            elif block_type == "A":
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
            elif block_type == "F":
                return FeedForwardBlock(c_hidden, dropout=dropout)
            elif block_type == "T":
                return TimestepBlock(c_hidden, c_r, conds=t_conds)
            else:
                raise Exception(f"Block type {block_type} not supported")

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(
                    nn.Sequential(
                        LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                        UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode="down", enabled=switch_level[i - 1]),
                    )
                )
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(
                    nn.Sequential(
                        LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
                        UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode="up", enabled=switch_level[i - 1]),
                    )
                )
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # Only the first block of every level except the lowest-resolution one
                    # receives the encoder skip connection.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(
                        block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i]
                    )
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)  # General init
        nn.init.normal_(self.clip_txt_mapper.weight, std=0.02)  # conditionings
        nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)  # conditionings
        nn.init.normal_(self.clip_img_mapper.weight, std=0.02)  # conditionings
        torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        nn.init.constant_(self.clf[1].weight, 0)  # outputs

        # blocks: shrink the last MLP layer of residual blocks, zero the TimestepBlock mappers
        for level_block in self.down_blocks + self.up_blocks:
            for block in level_block:
                if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
                elif isinstance(block, TimestepBlock):
                    for layer in block.modules():
                        if isinstance(layer, nn.Linear):
                            nn.init.constant_(layer.weight, 0)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero bias for Conv2d / Linear modules."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _is_block(block, block_cls):
        """True if ``block`` is a ``block_cls``, either directly or as FSDP's ``_fsdp_wrapped_module``."""
        return isinstance(block, block_cls) or isinstance(getattr(block, "_fsdp_wrapped_module", None), block_cls)

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of ``r * max_positions`` with ``self.c_r`` features.

        ``r`` is indexed as ``r[:, None]``, so a 1-D (B,) tensor is expected. Returns (B, c_r);
        the last column is zero-padded when ``c_r`` is odd.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        return emb

    def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
        """Build the conditioning token sequence from the three CLIP inputs.

        ``clip_txt`` goes through a Linear layer token by token. Each pooled-text / image vector
        is expanded into ``c_clip_seq`` tokens; 2-D pooled/image inputs are treated as one vector
        per sample. The three sequences are concatenated along dim 1 and LayerNormed.
        """
        clip_txt = self.clip_txt_mapper(clip_txt)
        if len(clip_txt_pooled.shape) == 2:
            # Bug fix: the unsqueezed tensor used to be stored in `clip_txt_pool` and then
            # overwritten, so a 2-D pooled embedding crashed in the `.view` below.
            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
        if len(clip_img.shape) == 2:
            clip_img = clip_img.unsqueeze(1)
        clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(
            clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1
        )
        clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
        clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip, cnet=None):
        """Run the down path; returns per-level outputs ordered from lowest resolution to highest.

        If ``cnet`` is given, it is called before every ResBlock. A non-None result is resized
        bilinearly to ``x`` and added to it.
        """
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if self._is_block(block, ResBlock):
                        if cnet is not None:
                            next_cnet = cnet()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(
                                    next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True
                                )
                        x = block(x)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        """Run the up path with encoder skips; ``cnet`` is handled as in ``_down_encode``."""
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if self._is_block(block, ResBlock):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            # Resize to the skip's spatial size when the two disagree.
                            x = torch.nn.functional.interpolate(
                                x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                            )
                        if cnet is not None:
                            next_cnet = cnet()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(
                                    next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True
                                )
                        x = block(x, skip)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
        """Predict the ``c_out``-channel output for input ``x``.

        Args:
            x: (B, c_in, H, W) input.
            r: (B,) scalar condition, embedded with ``gen_r_embedding``.
            clip_text, clip_text_pooled, clip_img: CLIP conditioning, see ``gen_c_embeddings``.
            cnet: optional zero-argument callable returning a tensor (or None) per ResBlock.
                ControlNet wrapping is not supported yet, so it is passed through unchanged.
            **kwargs: values for the names in ``self.t_conds``; missing ones default to zeros.
        """
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        # Model Blocks
        x = self.embedding(x)
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """In place: ``self = beta * self + (1 - beta) * src_model`` for parameters and buffers.

        Pairs are matched positionally via ``zip``, so ``src_model`` is assumed to have the
        same structure.
        """
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model):
    """Run the CLIP text model and return ``(last_hidden_state, pooled_embedding)``.

    When ``captions`` is given, it is tokenized (truncated / padded to
    ``tokenizer.model_max_length``), moved to ``text_model.device`` and passed to
    ``text_model`` as keyword arguments. Otherwise the pre-tokenized ``input_ids`` are used.

    Returns:
        ``hidden_states[-1]`` of the text model, and ``text_embeds`` with an extra axis
        inserted at dim 1.

    Differences from the upstream implementation (translated from the original notes): the
    ``is_eval`` and ``is_unconditional`` handling is deliberately done by the caller, and CLIP
    image embeddings are not supported for now.
    """
    if captions is None:
        encoder_output = text_model(input_ids, output_hidden_states=True)
    else:
        tokens = tokenizer(
            captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        )
        tokens = tokens.to(text_model.device)
        encoder_output = text_model(**tokens, output_hidden_states=True)

    last_hidden = encoder_output.hidden_states[-1]
    pooled = encoder_output.text_embeds.unsqueeze(1)
    return last_hidden, pooled
# return {"clip_text": text_embeddings, "clip_text_pooled": text_pooled_embeddings} # , "clip_img": image_embeddings}
def get_stage_c_conditions(
    batch: dict,
    effnet,
    effnet_preprocess,
    is_eval=False,
    is_unconditional=False,
    eval_image_embeds=False,
    return_fields=None,
    device=None,
):
    """Compute the EfficientNet conditioning for Stage C from ``batch["images"]``.

    Args:
        batch: dict; only the optional ``"images"`` tensor is read.
        effnet: callable encoder applied as ``effnet(effnet_preprocess(images))``.
        effnet_preprocess: preprocessing callable applied before ``effnet``.
        is_eval: evaluation mode (no random resizing / dropping).
        is_unconditional: with ``is_eval``, return an all-zero embedding instead of encoding.
        eval_image_embeds, return_fields: accepted for signature compatibility, ignored.
        device: where to put the images and embeddings. Defaults to the device of ``effnet``'s
            first parameter, or to ``images.device`` if ``effnet`` has none. (The previous
            version referenced ``self.device`` in this free function and raised NameError
            whenever images were present.)

    Returns:
        None when the batch has no images. Otherwise a (B, 16, h // 32, w // 32) tensor. In
        eval mode h, w are the image size rounded down to multiples of 32. In training they are
        the image size scaled by a random factor in [0.5, 1) and then rounded down to multiples
        of 32. In training, each sample is encoded with probability 0.9 and left zero otherwise.
    """
    images = batch.get("images", None)
    if images is None:
        return None

    if device is None:
        device = _first_parameter_device(effnet, images.device)
    images = images.to(device)

    if is_eval and not is_unconditional:
        return effnet(effnet_preprocess(images))

    effnet_factor = 1 if is_eval else np.random.uniform(0.5, 1)  # f64 to f32
    effnet_height = int(((images.size(-2) * effnet_factor) // 32) * 32)
    effnet_width = int(((images.size(-1) * effnet_factor) // 32) * 32)
    effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height // 32, effnet_width // 32, device=device)
    if not is_eval:
        effnet_images = torchvision.transforms.functional.resize(
            images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )
        # Keep the EfficientNet condition for ~90% of the samples; the rest stay all-zero.
        rand_idx = np.random.rand(len(images)) <= 0.9
        if rand_idx.any():
            effnet_embeddings[rand_idx] = effnet(effnet_preprocess(effnet_images[rand_idx]))
    return effnet_embeddings


def _first_parameter_device(module, fallback):
    """Return the device of ``module``'s first parameter, or ``fallback`` if it exposes none."""
    parameters = getattr(module, "parameters", None)
    if callable(parameters):
        for param in parameters():
            return param.device
    return fallback
# {"effnet": effnet_embeddings, "clip": conditions["clip_text_pooled"]}
# region gdf
class SimpleSampler:
    """Base class for the step rules used by ``GDF.sample``.

    Calling an instance increments ``current_step`` and delegates to ``step``,
    which subclasses must implement.
    """

    def __init__(self, gdf):
        self.gdf = gdf
        # Incremented before every step, so the first call runs with current_step == 0.
        self.current_step = -1

    def __call__(self, *args, **kwargs):
        self.current_step += 1
        return self.step(*args, **kwargs)

    def init_x(self, shape):
        """Return standard-normal starting noise of the given shape (on the default device)."""
        return torch.randn(*shape)

    def step(self, x, x0, epsilon, logSNR, logSNR_prev):
        """Compute the next sample; must be overridden by subclasses."""
        raise NotImplementedError("You should override the 'step' function.")
class DDIMSampler(SimpleSampler):
    """DDIM step; ``eta`` scales the freshly injected noise (eta == 0 is deterministic)."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0):
        trailing_dims = len(x0.shape) - 1

        def _broadcastable(coef):
            # Per-sample (B,) coefficients become (B, 1, ..., 1) to broadcast against x0.
            return coef.view(-1, *([1] * trailing_dims))

        a, b = self.gdf.input_scaler(logSNR)
        if len(a.shape) == 1:
            a = _broadcastable(a)
            b = _broadcastable(b)

        a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
        if len(a_prev.shape) == 1:
            a_prev = _broadcastable(a_prev)
            b_prev = _broadcastable(b_prev)

        if eta > 0:
            sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt()
        else:
            sigma_tau = 0

        # Noise is drawn even when eta == 0 (it is then multiplied by zero), so RNG state advances either way.
        fresh_noise = torch.randn_like(x0)
        return a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * fresh_noise
class DDPMSampler(DDIMSampler):
    """Stochastic sampler: the DDIM update with ``eta`` defaulting to 1 instead of 0."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1):
        # Identical update rule to DDIM; only the default amount of injected noise differs.
        return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta=eta)
class LCMSampler(SimpleSampler):
    """Re-noise the predicted ``x0`` to the previous noise level using fresh Gaussian noise."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev):
        # Only the target level (logSNR_prev) is used; x and logSNR are ignored.
        a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
        if a_prev.ndim == 1:
            trailing = [1] * (x0.ndim - 1)
            a_prev = a_prev.view(-1, *trailing)
            b_prev = b_prev.view(-1, *trailing)
        fresh_noise = torch.randn_like(epsilon)
        return x0 * a_prev + fresh_noise * b_prev
class GDF:
    """Diffusion wrapper that combines a noise schedule, input scaler, prediction target,
    noise conditioning and loss weighting.

    ``diffuse`` noises clean data for training. ``sample`` is a generator that runs the
    reverse process step by step.
    """

    def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
        self.schedule = schedule
        self.input_scaler = input_scaler
        self.target = target
        self.noise_cond = noise_cond
        self.loss_weight = loss_weight
        self.offset_noise = offset_noise

    def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
        """Let the input scaler compute its stretch limits for this schedule and return them."""
        stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift)
        return stretched_limits

    def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
        """Noise a clean batch.

        Args:
            x0: clean batch; dim 0 is the batch dimension.
            epsilon: noise to add; defaults to ``torch.randn_like(x0)``.
            t: times passed to the schedule; when None the schedule receives the batch size instead.
            shift: logSNR shift forwarded to the schedule.
            loss_shift: shift forwarded to the loss weight.
            offset: per-(batch, channel) offset noise, only used when ``self.offset_noise > 0``
                (drawn randomly when None).

        Returns:
            ``(noised, epsilon, target, logSNR, noise_cond, loss_weight)`` where
            ``noised = x0 * a + epsilon * b`` and ``epsilon`` includes any offset noise.
        """
        if epsilon is None:
            epsilon = torch.randn_like(x0)
        if self.offset_noise > 0:
            if offset is None:
                offset = torch.randn([x0.size(0), x0.size(1)] + [1] * (len(x0.shape) - 2)).to(x0.device)
            epsilon = epsilon + offset * self.offset_noise
        logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device)
        a, b = self.input_scaler(logSNR)  # B
        if len(a.shape) == 1:
            # Reshape per-sample scalers to (B, 1, ..., 1) so they broadcast over x0.
            a, b = a.view(-1, *[1] * (len(x0.shape) - 1)), b.view(-1, *[1] * (len(x0.shape) - 1))
        target = self.target(x0, epsilon, logSNR, a, b)
        return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift)

    def undiffuse(self, x, logSNR, pred):
        """Recover ``(x0, epsilon)`` estimates from the noised ``x`` and the model prediction ``pred``."""
        a, b = self.input_scaler(logSNR)
        if len(a.shape) == 1:
            a, b = a.view(-1, *[1] * (len(x.shape) - 1)), b.view(-1, *[1] * (len(x.shape) - 1))
        return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b)

    @staticmethod
    def _cat_cfg_input(v, v_u):
        """Concatenate one conditional input with its unconditional counterpart along dim 0.

        Tensors are concatenated directly. Lists are concatenated element-wise (non-tensor
        pairs become None). Dicts are concatenated per key, with missing unconditional
        keys zero-filled. Any other value becomes None.
        """
        if isinstance(v, torch.Tensor):
            return torch.cat([v, v_u], dim=0)
        if isinstance(v, list):
            return [
                torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None
                for vi, vi_u in zip(v, v_u)
            ]
        if isinstance(v, dict):
            return {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v}
        return None

    def sample(
        self,
        model,
        model_inputs,
        shape,
        unconditional_inputs=None,
        sampler=None,
        schedule=None,
        t_start=1.0,
        t_end=0.0,
        timesteps=20,
        x_init=None,
        cfg=3.0,
        cfg_t_stop=None,
        cfg_t_start=None,
        cfg_rho=0.7,
        sampler_params=None,
        shift=1,
        device="cpu",
    ):
        """Run the reverse diffusion process as a generator.

        Yields ``(x0, x, pred)`` after every step. The caller may ``send`` a dict to update
        ``cfg``, ``cfg_rho``, ``sampler``, ``model_inputs``, ``x`` or ``x_init`` for later steps
        (``x_init`` is accepted but has no further effect once sampling has started).

        Args:
            model: called as ``model(x, noise_cond, **model_inputs)``.
            model_inputs: dict of conditioning inputs (tensors, lists of tensors, or dicts of tensors).
            shape: shape of the initial noise when ``x_init`` is None.
            unconditional_inputs: inputs for the unconditional CFG branch. Defaults to
                ``torch.zeros_like`` of every ``model_inputs`` value. Paired with
                ``model_inputs`` by position, not by key.
            sampler: step rule; defaults to ``DDPMSampler(self)``.
            schedule: overrides ``self.schedule`` during sampling.
            t_start, t_end, timesteps: ``timesteps`` steps over ``linspace(t_start, t_end, timesteps + 1)``.
            x_init: starting latent (cloned); otherwise ``sampler.init_x(shape)`` moved to ``device``.
            cfg: guidance scale, or a 2-element list/tuple interpolated as ``cfg[0] * t + cfg[1] * (1 - t)``.
                None disables classifier-free guidance.
            cfg_t_stop, cfg_t_start: CFG is applied only while ``cfg_t_stop <= t <= cfg_t_start``
                (either bound may be None). Other steps run the conditional branch alone.
            cfg_rho: blend factor for rescaling the guided prediction to the std of the conditional one.
            sampler_params: extra keyword arguments for every sampler call.
            shift: logSNR shift forwarded to the schedule.
            device: device for the logSNR range and freshly drawn initial noise.
        """
        sampler_params = {} if sampler_params is None else sampler_params
        if sampler is None:
            sampler = DDPMSampler(self)
        r_range = torch.linspace(t_start, t_end, timesteps + 1)
        schedule = self.schedule if schedule is None else schedule
        batch_size = shape[0] if x_init is None else x_init.size(0)
        logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, batch_size).to(device)

        x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()

        # Conditional-only inputs for steps run without CFG. model_inputs may be replaced by the
        # doubled (conditional + unconditional) batch below. Feeding that to the single-batch
        # branch would mismatch the batch size of x.
        cond_inputs = model_inputs
        if cfg is not None:
            if unconditional_inputs is None:
                unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
            model_inputs = {
                k: self._cat_cfg_input(v, v_u)
                for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
            }

        for i in range(timesteps):
            r = r_range[i].item()
            noise_cond = self.noise_cond(logSNR_range[i])
            in_cfg_window = (cfg_t_stop is None or r >= cfg_t_stop) and (cfg_t_start is None or r <= cfg_t_start)
            if cfg is not None and in_cfg_window:
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    cfg_val = cfg_val[0] * r + cfg_val[1] * (1 - r)
                pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)
                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    # Rescale the guided prediction towards the std of the conditional prediction.
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos / (std_cfg + 1e-9)) + pred_cfg * (1 - cfg_rho)
                else:
                    pred = pred_cfg
            else:
                inputs = model_inputs if cond_inputs is None else cond_inputs
                pred = model(x, noise_cond, **inputs)
            x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
            x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i + 1], **sampler_params)
            altered_vars = yield (x0, x, pred)

            # Update some running variables if the user wants
            if altered_vars is not None:
                cfg = altered_vars.get("cfg", cfg)
                cfg_rho = altered_vars.get("cfg_rho", cfg_rho)
                sampler = altered_vars.get("sampler", sampler)
                if "model_inputs" in altered_vars:
                    # Caller-supplied inputs are used as given in both branches (original behavior),
                    # since we cannot tell whether they are already CFG-batched.
                    model_inputs = altered_vars["model_inputs"]
                    cond_inputs = None
                x = altered_vars.get("x", x)
                x_init = altered_vars.get("x_init", x_init)
class BaseSchedule:
    """Base class for noise schedules mapping a diffusion time ``t`` to logSNR.

    Subclasses implement ``setup`` (receives the extra constructor arguments) and
    ``schedule(t, batch_size, ...)``. Calling an instance with a tensor evaluates the
    schedule at those times. Calling it with anything else treats the value as a batch
    size and passes ``t=None``, so the subclass can draw its own times.
    """

    def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs):
        self.setup(*args, **kwargs)
        self.limits = None
        self.discrete_steps = discrete_steps
        self.shift = shift
        if force_limits:
            self.reset_limits()

    def reset_limits(self, shift=1, disable=False):
        """Recompute the logSNR clamp bounds from the schedule's values at t=1 and t=0.

        Returns the new ``[value_at_t1, value_at_t0]`` list (i.e. [min, max] for schedules
        where logSNR decreases with t). Returns None when ``disable`` is set, or when the
        schedule cannot be evaluated at explicit times; in the latter case a warning is
        printed and the previous limits are kept.
        """
        previous = self.limits
        # Clear first so the fresh bounds are not clamped by stale ones (e.g. when shift changes).
        self.limits = None
        if disable:
            return None
        try:
            self.limits = self(torch.tensor([1.0, 0.0]), shift=shift).tolist()  # min, max
        except Exception:
            self.limits = previous
            print("WARNING: this schedule doesn't support t and will be unbounded")
            return None
        return self.limits

    def setup(self, *args, **kwargs):
        raise NotImplementedError("this method needs to be overriden")

    def schedule(self, *args, **kwargs):
        raise NotImplementedError("this method needs to be overriden")

    def __call__(self, t, *args, shift=1, **kwargs):
        if isinstance(t, torch.Tensor):
            batch_size = None
            if self.discrete_steps is not None:
                if t.dtype != torch.long:
                    # Snap continuous t onto the discrete grid of `discrete_steps` points.
                    t = (t * (self.discrete_steps - 1)).round().long()
                t = t / (self.discrete_steps - 1)
            t = t.clamp(0, 1)
        else:
            batch_size = t
            t = None
        logSNR = self.schedule(t, batch_size, *args, **kwargs)
        if shift * self.shift != 1:
            # Shifting the resolution by a factor k lowers logSNR by 2*log(k).
            logSNR += 2 * np.log(1 / (shift * self.shift))
        if self.limits is not None:
            logSNR = logSNR.clamp(*self.limits)
        return logSNR
class CosineSchedule(BaseSchedule):
    """Cosine noise schedule.

    var(t) = cos((s + t) / (1 + s) * pi / 2) ** 2 / cos(s / (1 + s) * pi / 2) ** 2,
    then clamped (or linearly mapped) into ``clamp_range``; logSNR = log(var / (1 - var)).
    """

    def setup(self, s=0.008, clamp_range=None, norm_instead=False):
        """
        Args:
            s: offset of the cosine schedule.
            clamp_range: bounds for var; defaults to a fresh ``[0.0001, 0.9999]`` per instance
                (a list default would be shared by every instance that stores it).
            norm_instead: if True, linearly map var into ``clamp_range`` instead of clamping.
        """
        if clamp_range is None:
            clamp_range = [0.0001, 0.9999]
        self.s = torch.tensor([s])
        self.clamp_range = clamp_range
        self.norm_instead = norm_instead
        # Normaliser so that var(0) == 1 before clamping.
        self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

    def schedule(self, t, batch_size):
        """Return logSNR at ``t``; when ``t`` is None, draw ``batch_size`` times in [0.001, 1]."""
        if t is None:
            t = (1 - torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0)
        s, min_var = self.s.to(t.device), self.min_var.to(t.device)
        var = torch.cos((s + t) / (1 + s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var
        if self.norm_instead:
            var = var * (self.clamp_range[1] - self.clamp_range[0]) + self.clamp_range[0]
        else:
            var = var.clamp(*self.clamp_range)
        logSNR = (var / (1 - var)).log()
        return logSNR
class BaseScaler:
    """Base class for input scalers mapping logSNR to the coefficients ``(a, b)`` used as
    ``x0 * a + noise * b``.

    After ``setup_limits``, ``(a, b)`` are linearly rescaled so that the schedule's values
    at t=0 and t=1 map to the ends of [0, 1].
    """

    def __init__(self):
        self.stretched_limits = None

    def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1):
        """Compute ``[min_a, max_a, min_b, max_b]`` from ``schedule`` at t=1 and t=0 and store it.

        The limits are measured on unstretched coefficients, so calling this again (for example
        with a different ``shift``) recomputes them instead of collapsing to the identity
        stretch. On error, the previous limits are restored and the exception propagates.
        """
        saved = self.stretched_limits
        # input_scaler is normally `self`; disable stretching while measuring the raw coefficients.
        self.stretched_limits = None
        try:
            min_logSNR = schedule(torch.ones(1), shift=shift)
            max_logSNR = schedule(torch.zeros(1), shift=shift)
            min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1]
            max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0]
        except Exception:
            self.stretched_limits = saved
            raise
        self.stretched_limits = [min_a, max_a, min_b, max_b]
        return self.stretched_limits

    def stretch_limits(self, a, b):
        """Linearly map ``a`` and ``b`` using the stored limits."""
        min_a, max_a, min_b, max_b = self.stretched_limits
        return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b)

    def scalers(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR):
        a, b = self.scalers(logSNR)
        if self.stretched_limits is not None:
            a, b = self.stretch_limits(a, b)
        return a, b
class VPScaler(BaseScaler):
    """Variance-preserving scaling: a**2 = sigmoid(logSNR), b**2 = 1 - a**2."""

    def scalers(self, logSNR):
        signal_var = torch.sigmoid(logSNR)
        return signal_var.sqrt(), (1 - signal_var).sqrt()
class EpsilonTarget:
    """Prediction target in which the model predicts the added noise (epsilon)."""

    def __call__(self, x0, epsilon, logSNR, a, b):
        # The regression target is the noise itself.
        return epsilon

    def x0(self, noised, pred, logSNR, a, b):
        # Invert noised = x0 * a + eps * b, taking eps = pred.
        scaled_noise = pred * b
        return (noised - scaled_noise) / a

    def epsilon(self, noised, pred, logSNR, a, b):
        # The prediction already is the noise estimate.
        return pred
class BaseNoiseCond:
    """Base class turning logSNR into the noise-conditioning value passed to the model.

    Subclasses implement ``cond``. ``setup`` receives any extra constructor arguments and
    runs after ``shift`` and ``clamp_range`` are set, so it may override them.
    """

    def __init__(self, *args, shift=1, clamp_range=None, **kwargs):
        self.shift = shift
        self.clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
        self.setup(*args, **kwargs)

    def setup(self, *args, **kwargs):
        """Optional hook for subclass-specific initialisation; does nothing by default."""

    def cond(self, logSNR):
        raise NotImplementedError("this method needs to be overriden")

    def __call__(self, logSNR):
        # Apply the logSNR shift on a copy so the caller's tensor is untouched.
        shifted = logSNR if self.shift == 1 else logSNR.clone() + 2 * np.log(self.shift)
        return self.cond(shifted).clamp(*self.clamp_range)
class CosineTNoiseCond(BaseNoiseCond):
    """Noise conditioning that maps logSNR back to a cosine-schedule time ``t``.

    It inverts var = cos((s + t) / (1 + s) * pi / 2) ** 2 / min_var, with var = sigmoid(logSNR).
    """

    def setup(self, s=0.008, clamp_range=None):
        """
        Args:
            s: offset of the cosine schedule being inverted.
            clamp_range: bounds applied to var (and, via the base class, to the output).
                Defaults to a fresh ``[0, 1]`` per instance; ``[0.0001, 0.9999]`` is an alternative.
        """
        self.s = torch.tensor([s])
        # NOTE(review): BaseNoiseCond.__init__ consumes a `clamp_range` constructor kwarg and does
        # not forward it here, so this value always replaces it -- confirm that is intended.
        self.clamp_range = [0, 1] if clamp_range is None else clamp_range
        self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

    def cond(self, logSNR):
        var = logSNR.sigmoid()
        var = var.clamp(*self.clamp_range)
        s, min_var = self.s.to(var.device), self.min_var.to(var.device)
        # cos(...) = sqrt(var * min_var)  =>  t = acos(sqrt(var * min_var)) / (pi / 2) * (1 + s) - s
        t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
        return t
# --- Loss Weighting
class BaseLossWeight:
    """Base class for per-sample loss weights computed from logSNR.

    Subclasses implement ``weight(logSNR, ...)``. Calling the instance optionally shifts
    logSNR and clamps the resulting weights.
    """

    def weight(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
        bounds = [-1e9, 1e9] if clamp_range is None else clamp_range
        # Shift on a copy so the caller's tensor is left untouched.
        shifted = logSNR if shift == 1 else logSNR.clone() + 2 * np.log(shift)
        raw = self.weight(shifted, *args, **kwargs)
        return raw.clamp(*bounds)
# class ComposedLossWeight(BaseLossWeight):
# def __init__(self, div, mul):
# self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
# self.div = [div] if isinstance(div, BaseLossWeight) else div
# def weight(self, logSNR):
# prod, div = 1, 1
# for m in self.mul:
# prod *= m.weight(logSNR)
# for d in self.div:
# div *= d.weight(logSNR)
# return prod/div
# class ConstantLossWeight(BaseLossWeight):
# def __init__(self, v=1):
# self.v = v
# def weight(self, logSNR):
# return torch.ones_like(logSNR) * self.v
# class SNRLossWeight(BaseLossWeight):
# def weight(self, logSNR):
# return logSNR.exp()
class P2LossWeight(BaseLossWeight):
    """P2 loss weight: ``(k + exp(s * logSNR)) ** -gamma``."""

    def __init__(self, k=1.0, gamma=1.0, s=1.0):
        self.k = k
        self.gamma = gamma
        self.s = s

    def weight(self, logSNR):
        snr_term = torch.exp(logSNR * self.s)
        return (self.k + snr_term).pow(-self.gamma)
# class SNRPlusOneLossWeight(BaseLossWeight):
# def weight(self, logSNR):
# return logSNR.exp() + 1
# class MinSNRLossWeight(BaseLossWeight):
# def __init__(self, max_snr=5):
# self.max_snr = max_snr
# def weight(self, logSNR):
# return logSNR.exp().clamp(max=self.max_snr)
# class MinSNRPlusOneLossWeight(BaseLossWeight):
# def __init__(self, max_snr=5):
# self.max_snr = max_snr
# def weight(self, logSNR):
# return (logSNR.exp() + 1).clamp(max=self.max_snr)
# class TruncatedSNRLossWeight(BaseLossWeight):
# def __init__(self, min_snr=1):
# self.min_snr = min_snr
# def weight(self, logSNR):
# return logSNR.exp().clamp(min=self.min_snr)
# class SechLossWeight(BaseLossWeight):
# def __init__(self, div=2):
# self.div = div
# def weight(self, logSNR):
# return 1/(logSNR/self.div).cosh()
# class DebiasedLossWeight(BaseLossWeight):
# def weight(self, logSNR):
# return 1/logSNR.exp().sqrt()
# class SigmoidLossWeight(BaseLossWeight):
# def __init__(self, s=1):
# self.s = s
# def weight(self, logSNR):
# return (logSNR * self.s).sigmoid()
class AdaptiveLossWeight(BaseLossWeight):
    """Weight each sample by 1 / (running loss of its logSNR bucket), clamped to ``weight_range``."""

    def __init__(self, logsnr_range=None, buckets=300, weight_range=None):
        """
        Args:
            logsnr_range: ``[low, high]`` logSNR span covered by the bucket boundaries;
                defaults to ``[-10, 10]``.
            buckets: number of buckets (``buckets - 1`` boundaries are placed over ``logsnr_range``).
            weight_range: clamp bounds for the returned weights; defaults to a fresh
                ``[1e-7, 1e7]`` per instance (a list default would be shared between instances).
        """
        logsnr_range = [-10, 10] if logsnr_range is None else logsnr_range
        weight_range = [1e-7, 1e7] if weight_range is None else weight_range
        # searchsorted over these buckets-1 boundaries yields indices 0..buckets-1.
        self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets - 1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, logSNR):
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
        return (1 / self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, logSNR, loss, beta=0.99):
        """Exponential-moving-average update of the buckets hit by ``logSNR`` with ``loss``."""
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
        # NOTE(review): if several samples share a bucket in one call, the indexed assignment keeps
        # only one of their updates (which one is unspecified) -- confirm this is acceptable.
        self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)
# endregion gdf