From a233ca07f043aa69160523316ead825ed2a9c399 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 19 May 2025 17:58:38 -0400 Subject: [PATCH] Fixed packed noisy model input shape --- flux_train_network.py | 2 +- library/flux_train_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index e1eec00d..96ed4b70 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -371,7 +371,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): else: img_ids = flux_utils.prepare_img_ids(bsz, latent_height // 2, latent_width // 2).to(device=accelerator.device) - assert packed_noisy_model_input.shape[2] * packed_noisy_model_input.shape[3] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids" + assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids" # get guidance # ensure guidance_scale in args is float diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index ed42aaf0..cb958304 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -272,7 +272,7 @@ def sample_image_inference( latent_height, latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3] img_ids = flux_utils.prepare_img_ids(1, latent_height // 2, latent_width // 2).to(device=accelerator.device) - assert packed_noisy_model_input.shape[2] * packed_noisy_model_input.shape[3] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids" + assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids" timesteps = get_schedule(sample_steps, noisy_model_input.shape[1], shift=True) # FLUX.1 dev -> shift=True t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None