From af8e216035128767234163a24debf2f4df5aa36d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 28 Oct 2024 22:08:57 +0900 Subject: [PATCH] Fix sample image gen to work with block swap --- library/sd3_train_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index a0202ad4..054d1b4a 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -364,6 +364,7 @@ def do_sample( x_c_nc = torch.cat([x, x], dim=0) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + mmdit.prepare_block_swap_before_forward() model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = model_output.float() batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) @@ -385,6 +386,7 @@ def do_sample( x = x + d * dt x = x.to(dtype) + mmdit.prepare_block_swap_before_forward() return x