Fix partitioned packed latent values

This commit is contained in:
rockerBOO
2025-05-08 19:00:22 -04:00
parent 9b35ef6dc9
commit a38e255939
2 changed files with 26 additions and 12 deletions

View File

@@ -363,8 +363,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
if args.partitioned_vae:
packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
img_ids = flux_utils.prepare_paritioned_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

View File

@@ -232,18 +232,30 @@ def sample_image_inference(
# sample image
weight_dtype = ae.dtype # TODO give dtype as argument
packed_latent_height = height // 16
packed_latent_width = width // 16
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
# VAE 8x compression
latent_height = height // 8
latent_width = width // 8
noisy_model_input = torch.randn(
1, # Batch size
16, # VAE channels
latent_height,
latent_width,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
if args.partitioned_vae:
packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
img_ids = flux_utils.prepare_paritioned_img_ids(1, packed_latent_height, packed_latent_width).to(device=accelerator.device)
else:
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device=accelerator.device)
timesteps = get_schedule(sample_steps, noisy_model_input.shape[1], shift=True) # FLUX.1 dev -> shift=True
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
if controlnet_image is not None:
@@ -255,7 +267,7 @@ def sample_image_inference(
with accelerator.autocast(), torch.no_grad():
x = denoise(
flux,
noise,
packed_noisy_model_input,
img_ids,
t5_out,
txt_ids,
@@ -268,7 +280,11 @@ def sample_image_inference(
neg_cond=neg_cond,
)
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
if args.partitioned_vae:
x = flux_utils.unpack_partitioned_latents(x, packed_latent_height, packed_latent_width)
else:
# unpack latents
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image
clean_memory_on_device(accelerator.device)