From d6f7e2e20cfe91eb0c7a5f4c277107f7b699d97f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 28 Feb 2025 14:08:27 -0500 Subject: [PATCH] Fix block swap for sample images --- library/flux_train_utils.py | 1 - library/lumina_train_util.py | 3 ++- lumina_train_network.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..c6d2baeb 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -317,7 +317,6 @@ def denoise( # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 11dd3feb..e008b3ce 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -604,7 +604,6 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps - def denoise( scheduler, model: lumina_models.NextDiT, @@ -648,6 +647,7 @@ def denoise( """ for i, t in enumerate(tqdm(timesteps)): + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - t / scheduler.config.num_train_timesteps # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -700,6 +700,7 @@ def denoise( noise_pred = -noise_pred img = scheduler.step(noise_pred, t, img, return_dict=False)[0] + model.prepare_block_swap_before_forward() return img diff --git a/lumina_train_network.py b/lumina_train_network.py index 3e003a92..60c39c20 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -367,6 +367,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() train_util.add_dit_training_arguments(parser)