Merge pull request #19 from rockerBOO/lumina-block-swap

Lumina block swap
This commit is contained in:
青龍聖者@bdsqlsz
2025-03-02 18:30:37 +08:00
committed by GitHub
9 changed files with 505 additions and 28 deletions

View File

@@ -604,7 +604,6 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
def denoise(
scheduler,
model: lumina_models.NextDiT,
@@ -648,6 +647,8 @@ def denoise(
"""
for i, t in enumerate(tqdm(timesteps)):
model.prepare_block_swap_before_forward()
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
current_timestep = 1 - t / scheduler.config.num_train_timesteps
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -664,6 +665,7 @@ def denoise(
# compute whether to apply classifier-free guidance based on current timestep
if current_timestep[0] < cfg_trunc_ratio:
model.prepare_block_swap_before_forward()
noise_pred_uncond = model(
img,
current_timestep,
@@ -702,6 +704,7 @@ def denoise(
noise_pred = -noise_pred
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
model.prepare_block_swap_before_forward()
return img