Fix sample image gen to work with block swap

This commit is contained in:
Kohya S
2024-10-28 22:08:57 +09:00
parent 1065dd1b56
commit af8e216035

View File

@@ -364,6 +364,7 @@ def do_sample(
x_c_nc = torch.cat([x, x], dim=0) x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
mmdit.prepare_block_swap_before_forward()
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float() model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
@@ -385,6 +386,7 @@ def do_sample(
x = x + d * dt x = x + d * dt
x = x.to(dtype) x = x.to(dtype)
mmdit.prepare_block_swap_before_forward()
return x return x