feat: split CFG processing in do_sample into two separate forward passes (positive, then negative) to reduce memory usage

This commit is contained in:
Kohya S
2026-02-11 18:00:10 +09:00
parent 9349c91c89
commit a7cd38dcaf

View File

@@ -347,19 +347,13 @@ def do_sample(
sigma = sigmas[i]
t = sigma.unsqueeze(0) # (1,)
dit.prepare_block_swap_before_forward()
if use_cfg:
# CFG: concat positive and negative
x_input = torch.cat([x, x], dim=0)
t_input = torch.cat([t, t], dim=0)
crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0)
padding_input = torch.cat([padding_mask, padding_mask], dim=0)
# CFG: two separate passes to reduce memory usage
pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
pos_out = pos_out.float()
neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
neg_out = neg_out.float()
model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input)
model_output = model_output.float()
pos_out, neg_out = model_output.chunk(2)
model_output = neg_out + guidance_scale * (pos_out - neg_out)
else:
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
@@ -370,7 +364,6 @@ def do_sample(
x = x + model_output * dt
x = x.to(dtype)
dit.prepare_block_swap_before_forward()
return x
@@ -577,6 +570,7 @@ def _sample_image_inference(
org_vae_device = vae.device
vae.to(accelerator.device)
decoded = vae.decode_to_pixels(latents)
input("Decoded")
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)