Add partitioned VAE

This commit is contained in:
rockerBOO
2025-05-07 23:09:27 -04:00
parent 80320d21fe
commit 9b35ef6dc9
4 changed files with 93 additions and 6 deletions

View File

@@ -141,6 +141,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
# Apply partitioned for Diffusion4k
if args.partitioned_vae:
ae.decoder.partitioned = True
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
def get_tokenize_strategy(self, args):
@@ -360,7 +364,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
if args.partitioned_vae:
packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
img_ids = flux_utils.prepare_paritioned_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
else:
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance
# ensure guidance_scale in args is float
@@ -408,7 +418,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
if args.partitioned_vae:
model_pred = flux_utils.unpack_partitioned_latents(model_pred, packed_latent_height, packed_latent_width)
else:
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

View File

@@ -54,6 +54,8 @@ class AutoEncoderParams:
z_channels: int
scale_factor: float
shift_factor: float
stride: int
partitioned: bool
def swish(x: Tensor) -> Tensor:
@@ -228,6 +230,8 @@ class Decoder(nn.Module):
in_channels: int,
resolution: int,
z_channels: int,
partitioned=False,
stride=1,
):
super().__init__()
self.ch = ch
@@ -236,6 +240,8 @@ class Decoder(nn.Module):
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
self.stride = stride
self.partitioned = partitioned
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
@@ -272,7 +278,7 @@ class Decoder(nn.Module):
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
def forward(self, z: Tensor, partitioned=None) -> Tensor:
# z to block_in
h = self.conv_in(z)
@@ -291,9 +297,55 @@ class Decoder(nn.Module):
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
# Diffusion4k
partitioned = partitioned if partitioned is not None else self.partitioned
if self.stride > 1 and partitioned:
    h = self.norm_out(h)  # NOTE(review): both other branches normalize before swish; omitting it here looks like a bug — confirm against Diffusion4k reference
    h = swish(h)
overlap_size = 1 # because last conv kernel_size = 3
res = []
partitioned_height = h.shape[2] // self.stride
partitioned_width = h.shape[3] // self.stride
assert self.stride == 2 # only support stride = 2 for now
rows = []
for i in range(0, h.shape[2], partitioned_height):
row = []
for j in range(0, h.shape[3], partitioned_width):
partition = h[:,:, max(i - overlap_size, 0) : min(i + partitioned_height + overlap_size, h.shape[2]), max(j - overlap_size, 0) : min(j + partitioned_width + overlap_size, h.shape[3])]
# zero-pad the outer image borders where no neighboring overlap exists (conv_out kernel_size=3)
if i==0 and j==0:
partition = torch.nn.functional.pad(partition, (1, 0, 1, 0), "constant", 0)
elif i==0:
partition = torch.nn.functional.pad(partition, (0, 1, 1, 0), "constant", 0)
elif i>0 and j==0:
partition = torch.nn.functional.pad(partition, (1, 0, 0, 1), "constant", 0)
elif i>0 and j>0:
partition = torch.nn.functional.pad(partition, (0, 1, 0, 1), "constant", 0)
partition = torch.nn.functional.interpolate(partition, scale_factor=self.stride, mode='nearest')
partition = self.conv_out(partition)
partition = partition[:,:,overlap_size:partitioned_height*2+overlap_size,overlap_size:partitioned_width*2+overlap_size]
row.append(partition)
rows.append(row)
for row in rows:
res.append(torch.cat(row, dim=3))
h = torch.cat(res, dim=2)
# Diffusion4k
elif self.stride > 1:
h = self.norm_out(h)
h = torch.nn.functional.interpolate(h, scale_factor=self.stride, mode='nearest')
h = swish(h)
h = self.conv_out(h)
else:
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
@@ -404,6 +456,9 @@ configs = {
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
# Diffusion4k
stride=1,
partitioned=False,
),
),
"schnell": ModelSpec(
@@ -436,6 +491,9 @@ configs = {
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
# Diffusion4k
stride=1,
partitioned=False,
),
),
}

View File

@@ -680,3 +680,4 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
parser.add_argument("--partitioned_vae", action="store_true", help="Partitioned VAE from Diffusion4k paper")

View File

@@ -346,6 +346,13 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids
def prepare_partitioned_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
    """Build positional img ids for the partitioned-VAE (Diffusion4k) path.

    Unlike ``prepare_img_ids``, the caller passes the raw latent height/width
    (not pre-divided by 2), so the halving happens here.

    Args:
        batch_size: number of samples in the batch.
        packed_latent_height: latent height before halving.
        packed_latent_width: latent width before halving.

    Returns:
        Tensor of shape ``(batch_size, (H // 2) * (W // 2), 3)`` where channel 1
        carries the row index, channel 2 the column index, and channel 0 is zero.
    """
    grid_h = packed_latent_height // 2
    grid_w = packed_latent_width // 2
    img_ids = torch.zeros(grid_h, grid_w, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(grid_h)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(grid_w)[None, :]
    # plain-torch equivalent of einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
    img_ids = img_ids.reshape(1, grid_h * grid_w, 3).repeat(batch_size, 1, 1)
    return img_ids


# Backward-compatible alias: callers (and the training loop above) use the
# original misspelled name.
prepare_paritioned_img_ids = prepare_partitioned_img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
@@ -354,6 +361,13 @@ def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_wid
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
return x
def unpack_partitioned_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor:
    """
    Inverse packing for the partitioned-VAE path. The spatial grid is the
    given sizes halved: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)] with
    h=packed_latent_height // 2, w=packed_latent_width // 2, ph=pw=2.
    """
    grid_h = packed_latent_height // 2
    grid_w = packed_latent_width // 2
    return einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=grid_h, w=grid_w, ph=2, pw=2)
def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""