diff --git a/flux_train_network.py b/flux_train_network.py
index def44155..20ba321d 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -141,6 +141,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
         ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
 
+        # Apply partitioned decoding for Diffusion4k
+        if args.partitioned_vae:
+            ae.decoder.partitioned = True
+
         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
 
     def get_tokenize_strategy(self, args):
@@ -360,7 +364,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         # pack latents and get img_ids
         packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
         packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
-        img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
+
+        if args.partitioned_vae:
+            packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
+            img_ids = flux_utils.prepare_partitioned_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
+        else:
+            packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
+            img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
 
         # get guidance
         # ensure guidance_scale in args is float
@@ -408,7 +418,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         )
 
         # unpack latents
-        model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
+        if args.partitioned_vae:
+            # latents were packed with the unhalved spatial size, so unpack accordingly
+            model_pred = flux_utils.unpack_partitioned_latents(model_pred, packed_latent_height, packed_latent_width)
+        else:
+            model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
 
         # apply model prediction type
         model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
diff --git a/library/flux_models.py b/library/flux_models.py
index 328ad481..2fbf0902 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -54,6 +54,8 @@ class AutoEncoderParams:
     z_channels: int
     scale_factor: float
    shift_factor: float
+    stride: int
+    partitioned: bool
 
 
 def swish(x: Tensor) -> Tensor:
@@ -228,6 +230,8 @@ class Decoder(nn.Module):
         in_channels: int,
         resolution: int,
         z_channels: int,
+        partitioned=False,
+        stride=1,
     ):
         super().__init__()
         self.ch = ch
@@ -236,6 +240,8 @@ class Decoder(nn.Module):
         self.resolution = resolution
         self.in_channels = in_channels
         self.ffactor = 2 ** (self.num_resolutions - 1)
+        self.stride = stride
+        self.partitioned = partitioned
 
         # compute in_ch_mult, block_in and curr_res at lowest res
         block_in = ch * ch_mult[self.num_resolutions - 1]
@@ -272,7 +278,7 @@ class Decoder(nn.Module):
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
 
-    def forward(self, z: Tensor) -> Tensor:
+    def forward(self, z: Tensor, partitioned=None) -> Tensor:
         # z to block_in
         h = self.conv_in(z)
 
@@ -291,9 +297,55 @@ class Decoder(nn.Module):
                 h = self.up[i_level].upsample(h)
 
         # end
-        h = self.norm_out(h)
-        h = swish(h)
-        h = self.conv_out(h)
+
+        # Diffusion4k: fall back to the module-level flag when the caller does not override it
+        partitioned = partitioned if partitioned is not None else self.partitioned
+        if self.stride > 1 and partitioned:
+            h = swish(h)  # NOTE(review): self.norm_out is skipped in this branch, unlike the other two -- confirm intentional
+
+            overlap_size = 1  # because last conv kernel_size = 3
+            res = []
+            partitioned_height = h.shape[2] // self.stride
+            partitioned_width = h.shape[3] // self.stride
+
+            assert self.stride == 2  # only support stride = 2 for now
+            rows = []
+            for i in range(0, h.shape[2], partitioned_height):
+                row = []
+                for j in range(0, h.shape[3], partitioned_width):
+                    partition = h[:, :, max(i - overlap_size, 0) : min(i + partitioned_height + overlap_size, h.shape[2]), max(j - overlap_size, 0) : min(j + partitioned_width + overlap_size, h.shape[3])]
+
+                    # zero-pad the image-border sides that have no overlap neighbor (pad order: left, right, top, bottom)
+                    if i == 0 and j == 0:
+                        partition = torch.nn.functional.pad(partition, (1, 0, 1, 0), "constant", 0)
+                    elif i == 0:
+                        partition = torch.nn.functional.pad(partition, (0, 1, 1, 0), "constant", 0)
+                    elif i > 0 and j == 0:
+                        partition = torch.nn.functional.pad(partition, (1, 0, 0, 1), "constant", 0)
+                    elif i > 0 and j > 0:
+                        partition = torch.nn.functional.pad(partition, (0, 1, 0, 1), "constant", 0)
+
+                    partition = torch.nn.functional.interpolate(partition, scale_factor=self.stride, mode="nearest")
+                    partition = self.conv_out(partition)
+                    partition = partition[:, :, overlap_size : partitioned_height * 2 + overlap_size, overlap_size : partitioned_width * 2 + overlap_size]  # NOTE(review): crop offset uses overlap_size (pre-upsample units); verify alignment after 2x interpolate
+
+                    row.append(partition)
+                rows.append(row)
+
+            for row in rows:
+                res.append(torch.cat(row, dim=3))
+
+            h = torch.cat(res, dim=2)
+        # Diffusion4k: plain upsampled decode, no partitioning
+        elif self.stride > 1:
+            h = self.norm_out(h)
+            h = torch.nn.functional.interpolate(h, scale_factor=self.stride, mode="nearest")
+            h = swish(h)
+            h = self.conv_out(h)
+        else:
+            h = self.norm_out(h)
+            h = swish(h)
+            h = self.conv_out(h)
 
         return h
 
@@ -404,6 +456,9 @@ configs = {
             z_channels=16,
             scale_factor=0.3611,
             shift_factor=0.1159,
+            # Diffusion4k
+            stride=1,
+            partitioned=False,
         ),
     ),
     "schnell": ModelSpec(
@@ -436,6 +491,9 @@ configs = {
             z_channels=16,
             scale_factor=0.3611,
             shift_factor=0.1159,
+            # Diffusion4k
+            stride=1,
+            partitioned=False,
         ),
     ),
 }
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index 5f6867a8..4537ee63 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -680,3 +680,4 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         default=3.0,
         help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
     )
+    parser.add_argument("--partitioned_vae", action="store_true", help="Partitioned VAE decoding from the Diffusion4k paper")
diff --git a/library/flux_utils.py b/library/flux_utils.py
index 8be1d63e..bb4b4f2c 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -346,6 +346,13 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
     img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
     return img_ids
 
+def prepare_partitioned_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
+    img_ids = torch.zeros(packed_latent_height // 2, packed_latent_width // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width // 2)[None, :]
+    img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
+    return img_ids
+
 
 def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
     """
@@ -354,6 +361,13 @@ def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int):
     x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
     return x
 
+def unpack_partitioned_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor:
+    """
+    x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2; h and w are the halved packed sizes
+    """
+    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height // 2, w=packed_latent_width // 2, ph=2, pw=2)
+    return x
+
 
 def pack_latents(x: torch.Tensor) -> torch.Tensor:
     """