mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
feat: split CFG processing in do_sample function to reduce memory usage
This commit is contained in:
@@ -347,19 +347,13 @@ def do_sample(
|
||||
sigma = sigmas[i]
|
||||
t = sigma.unsqueeze(0) # (1,)
|
||||
|
||||
dit.prepare_block_swap_before_forward()
|
||||
|
||||
if use_cfg:
|
||||
-# CFG: concat positive and negative
-x_input = torch.cat([x, x], dim=0)
-t_input = torch.cat([t, t], dim=0)
-crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0)
-padding_input = torch.cat([padding_mask, padding_mask], dim=0)
+# CFG: two separate passes to reduce memory usage
+pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
+pos_out = pos_out.float()
+neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
+neg_out = neg_out.float()

-model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input)
-model_output = model_output.float()
-
-pos_out, neg_out = model_output.chunk(2)
 model_output = neg_out + guidance_scale * (pos_out - neg_out)
|
||||
else:
|
||||
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
|
||||
@@ -370,7 +364,6 @@ def do_sample(
|
||||
x = x + model_output * dt
|
||||
x = x.to(dtype)
|
||||
|
||||
dit.prepare_block_swap_before_forward()
|
||||
return x
|
||||
|
||||
|
||||
@@ -577,6 +570,7 @@ def _sample_image_inference(
|
||||
org_vae_device = vae.device
|
||||
vae.to(accelerator.device)
|
||||
decoded = vae.decode_to_pixels(latents)
|
||||
+input("Decoded")
|
||||
vae.to(org_vae_device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user