From a38e255939b0013e5d75acd0bb777fbc5849401c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 8 May 2025 19:00:22 -0400 Subject: [PATCH] Fix partitioned packed latent values --- flux_train_network.py | 2 -- library/flux_train_utils.py | 36 ++++++++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 20ba321d..1b6c5533 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -363,8 +363,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - if args.partitioned_vae: packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3] img_ids = flux_utils.prepare_paritioned_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 4537ee63..54f473ea 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -232,18 +232,30 @@ def sample_image_inference( # sample image weight_dtype = ae.dtype # TOFO give dtype as argument - packed_latent_height = height // 16 - packed_latent_width = width // 16 - noise = torch.randn( - 1, - packed_latent_height * packed_latent_width, - 16 * 2 * 2, + + # VAE 8x compression + latent_height = height // 8 + latent_width = width // 8 + + noisy_model_input = torch.randn( + 1, # Batch size + 16, # VAE channels + latent_height, + latent_width, device=accelerator.device, dtype=weight_dtype, generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, ) - timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True - img_ids = flux_utils.prepare_img_ids(1, 
packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + + if args.partitioned_vae: + packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3] + img_ids = flux_utils.prepare_paritioned_img_ids(1, packed_latent_height, packed_latent_width).to(device=accelerator.device) + else: + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device=accelerator.device) + timesteps = get_schedule(sample_steps, packed_noisy_model_input.shape[1], shift=True) # FLUX.1 dev -> shift=True t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None if controlnet_image is not None: @@ -255,7 +267,7 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise( flux, - noise, + packed_noisy_model_input, img_ids, t5_out, txt_ids, @@ -268,7 +280,11 @@ def sample_image_inference( neg_cond=neg_cond, ) - x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) + if args.partitioned_vae: + x = flux_utils.unpack_partitioned_latents(x, packed_latent_height, packed_latent_width) + else: + # unpack latents + x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image clean_memory_on_device(accelerator.device)