mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Fix partitioned packed latent values
This commit is contained in:
@@ -363,8 +363,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# pack latents and get img_ids
|
||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
|
||||
if args.partitioned_vae:
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
|
||||
img_ids = flux_utils.prepare_paritioned_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
||||
|
||||
@@ -232,18 +232,30 @@ def sample_image_inference(
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
packed_latent_height = height // 16
|
||||
packed_latent_width = width // 16
|
||||
noise = torch.randn(
|
||||
1,
|
||||
packed_latent_height * packed_latent_width,
|
||||
16 * 2 * 2,
|
||||
|
||||
# VAE 8x compression
|
||||
latent_height = height // 8
|
||||
latent_width = width // 8
|
||||
|
||||
noisy_model_input = torch.randn(
|
||||
1, # Batch size
|
||||
16, # VAE channels
|
||||
latent_height,
|
||||
latent_width,
|
||||
device=accelerator.device,
|
||||
dtype=weight_dtype,
|
||||
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
||||
)
|
||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||
|
||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||
|
||||
if args.partitioned_vae:
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
|
||||
img_ids = flux_utils.prepare_paritioned_img_ids(1, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
||||
else:
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
||||
timesteps = get_schedule(sample_steps, noisy_model_input.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
||||
|
||||
if controlnet_image is not None:
|
||||
@@ -255,7 +267,7 @@ def sample_image_inference(
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(
|
||||
flux,
|
||||
noise,
|
||||
packed_noisy_model_input,
|
||||
img_ids,
|
||||
t5_out,
|
||||
txt_ids,
|
||||
@@ -268,7 +280,11 @@ def sample_image_inference(
|
||||
neg_cond=neg_cond,
|
||||
)
|
||||
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
if args.partitioned_vae:
|
||||
x = flux_utils.unpack_partitioned_latents(x, packed_latent_height, packed_latent_width)
|
||||
else:
|
||||
# unpack latents
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
# latent to image
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
Reference in New Issue
Block a user