Fixed packed noisy model input shape

This commit is contained in:
rockerBOO
2025-05-19 17:58:38 -04:00
parent 0a4c309def
commit a233ca07f0
2 changed files with 2 additions and 2 deletions

View File

@@ -371,7 +371,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
else:
img_ids = flux_utils.prepare_img_ids(bsz, latent_height // 2, latent_width // 2).to(device=accelerator.device)
- assert packed_noisy_model_input.shape[2] * packed_noisy_model_input.shape[3] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids"
+ assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids"
# get guidance
# ensure guidance_scale in args is float

View File

@@ -272,7 +272,7 @@ def sample_image_inference(
latent_height, latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
img_ids = flux_utils.prepare_img_ids(1, latent_height // 2, latent_width // 2).to(device=accelerator.device)
- assert packed_noisy_model_input.shape[2] * packed_noisy_model_input.shape[3] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids"
+ assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids"
timesteps = get_schedule(sample_steps, noisy_model_input.shape[1], shift=True) # FLUX.1 dev -> shift=True
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None