diff --git a/flux_train_network.py b/flux_train_network.py index e1eec00d..96ed4b70 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -371,7 +371,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): else: img_ids = flux_utils.prepare_img_ids(bsz, latent_height // 2, latent_width // 2).to(device=accelerator.device) - assert packed_noisy_model_input.shape[2] * packed_noisy_model_input.shape[3] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids" + assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids" # get guidance # ensure guidance_scale in args is float diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index ed42aaf0..cb958304 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -272,7 +272,7 @@ def sample_image_inference( latent_height, latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3] img_ids = flux_utils.prepare_img_ids(1, latent_height // 2, latent_width // 2).to(device=accelerator.device) - assert packed_noisy_model_input.shape[2] * packed_noisy_model_input.shape[3] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids" + assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids" timesteps = get_schedule(sample_steps, noisy_model_input.shape[1], shift=True) # FLUX.1 dev -> shift=True t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None