From af5f2c47b139ff7f52cbda81b3c6628d16f09332 Mon Sep 17 00:00:00 2001
From: umisetokikaze
Date: Wed, 11 Mar 2026 22:59:38 +0900
Subject: [PATCH] feat: implement checkpointing in predict_noise and
 predict_noise_xl functions

---
 library/leco_train_util.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/library/leco_train_util.py b/library/leco_train_util.py
index c2c3eda2..987f1233 100644
--- a/library/leco_train_util.py
+++ b/library/leco_train_util.py
@@ -5,6 +5,8 @@ from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
 import torch
 import yaml
 
+from torch.utils.checkpoint import checkpoint
+
 from library import sdxl_train_util
 
 
@@ -404,8 +406,6 @@ def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_p
 def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
     return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
 
-
-
 def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
     text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
     pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(
@@ -414,10 +414,20 @@ def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbed
     return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
 
 
+def _run_with_checkpoint(function, *args):
+    if torch.is_grad_enabled():
+        return checkpoint(function, *args, use_reentrant=False)
+    return function(*args)
+
+
 def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
     latent_model_input = torch.cat([latents] * 2)
     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
-    noise_pred = unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
+
+    def run_unet(model_input, encoder_hidden_states):
+        return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+    noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
     return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -462,6 +472,7 @@ def get_add_time_ids(
     add_time_ids = add_time_ids.to(device)
     return add_time_ids
 
+
 def predict_noise_xl(
     unet,
     scheduler,
@@ -480,11 +491,15 @@ def predict_noise_xl(
     size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
     vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
 
-    noise_pred = unet(latent_model_input, timestep, prompt_embeds.text_embeds, vector_embedding)
+    def run_unet(model_input, text_embeds, vector_embeds):
+        return unet(model_input, timestep, text_embeds, vector_embeds)
+
+    noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
     return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+
 def diffusion_xl(
     unet,
     scheduler,