diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py index cece4eb3..617f8d53 100644 --- a/library/anima_train_utils.py +++ b/library/anima_train_utils.py @@ -347,19 +347,13 @@ def do_sample( sigma = sigmas[i] t = sigma.unsqueeze(0) # (1,) - dit.prepare_block_swap_before_forward() - if use_cfg: - # CFG: concat positive and negative - x_input = torch.cat([x, x], dim=0) - t_input = torch.cat([t, t], dim=0) - crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0) - padding_input = torch.cat([padding_mask, padding_mask], dim=0) + # CFG: two separate passes to reduce memory usage + pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask) + pos_out = pos_out.float() + neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask) + neg_out = neg_out.float() - model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input) - model_output = model_output.float() - - pos_out, neg_out = model_output.chunk(2) model_output = neg_out + guidance_scale * (pos_out - neg_out) else: model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask) @@ -370,7 +364,6 @@ def do_sample( x = x + model_output * dt x = x.to(dtype) - dit.prepare_block_swap_before_forward() return x @@ -577,6 +570,7 @@ def _sample_image_inference( org_vae_device = vae.device vae.to(accelerator.device) decoded = vae.decode_to_pixels(latents) + input("Decoded") vae.to(org_vae_device) clean_memory_on_device(accelerator.device)