feat: implement checkpointing in predict_noise and predict_noise_xl functions

This commit is contained in:
umisetokikaze
2026-03-11 22:59:38 +09:00
parent 9ffdd68c47
commit af5f2c47b1

View File

@@ -5,6 +5,8 @@ from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
import torch
import yaml
from torch.utils.checkpoint import checkpoint
from library import sdxl_train_util
@@ -404,8 +406,6 @@ def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_p
def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(
@@ -414,10 +414,20 @@ def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbed
return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
def _run_with_checkpoint(function, *args):
if torch.is_grad_enabled():
return checkpoint(function, *args, use_reentrant=False)
return function(*args)
def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
noise_pred = unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
def run_unet(model_input, encoder_hidden_states):
return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -462,6 +472,7 @@ def get_add_time_ids(
add_time_ids = add_time_ids.to(device)
return add_time_ids
def predict_noise_xl(
unet,
scheduler,
@@ -480,11 +491,15 @@ def predict_noise_xl(
size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
noise_pred = unet(latent_model_input, timestep, prompt_embeds.text_embeds, vector_embedding)
def run_unet(model_input, text_embeds, vector_embeds):
return unet(model_input, timestep, text_embeds, vector_embeds)
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
def diffusion_xl(
unet,
scheduler,