From 46c38e0634e9f2bc4f63f98639bb6cd0dc30bf2d Mon Sep 17 00:00:00 2001 From: umisetokikaze Date: Wed, 11 Mar 2026 22:50:49 +0900 Subject: [PATCH] feat: update predict_noise_xl to use vector embedding from add_time_ids --- library/leco_train_util.py | 17 ++++++----- tests/library/test_leco_train_util.py | 43 +++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/library/leco_train_util.py b/library/leco_train_util.py index 14da3c3d..c2c3eda2 100644 --- a/library/leco_train_util.py +++ b/library/leco_train_util.py @@ -6,6 +6,7 @@ from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union import torch import yaml +from library import sdxl_train_util ResolutionValue = Union[int, Tuple[int, int]] @@ -461,7 +462,6 @@ def get_add_time_ids( add_time_ids = add_time_ids.to(device) return add_time_ids - def predict_noise_xl( unet, scheduler, @@ -473,13 +473,14 @@ def predict_noise_xl( ): latent_model_input = torch.cat([latents] * 2) latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) - added_cond_kwargs = {"text_embeds": prompt_embeds.pooled_embeds, "time_ids": add_time_ids} - noise_pred = unet( - latent_model_input, - timestep, - encoder_hidden_states=prompt_embeds.text_embeds, - added_cond_kwargs=added_cond_kwargs, - ).sample + + orig_size = add_time_ids[:, :2] + crop_size = add_time_ids[:, 2:4] + target_size = add_time_ids[:, 4:6] + size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device) + vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1) + + noise_pred = unet(latent_model_input, timestep, prompt_embeds.text_embeds, vector_embedding) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) diff --git a/tests/library/test_leco_train_util.py b/tests/library/test_leco_train_util.py index e575614f..12da6ef9 100644 --- a/tests/library/test_leco_train_util.py +++ b/tests/library/test_leco_train_util.py @@ -1,5 +1,7 @@ from pathlib import Path +import torch + from library.leco_train_util import load_prompt_settings @@ -69,3 +71,44 @@ neutral: "" assert fourth.action == "enhance" assert fourth.multiplier == -1.25 + + +def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids(): + from library import sdxl_train_util + from library.leco_train_util import PromptEmbedsXL, predict_noise_xl + + class DummyScheduler: + def scale_model_input(self, latent_model_input, timestep): + return latent_model_input + + class DummyUNet: + def __call__(self, x, timesteps, context, y): + self.x = x + self.timesteps = timesteps + self.context = context + self.y = y + return torch.zeros_like(x) + + latents = torch.randn(1, 4, 8, 8) + prompt_embeds = PromptEmbedsXL( + text_embeds=torch.randn(2, 77, 2048), + pooled_embeds=torch.randn(2, 1280), + ) + add_time_ids = torch.tensor( + [ + [1024, 1024, 0, 0, 1024, 1024], + [1024, 1024, 0, 0, 1024, 1024], + ], + dtype=prompt_embeds.pooled_embeds.dtype, + ) + + unet = DummyUNet() + noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids) + + expected_size_embeddings = sdxl_train_util.get_size_embeddings( + add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device + ).to(prompt_embeds.pooled_embeds.dtype) + + assert noise_pred.shape == latents.shape + assert unet.context is prompt_embeds.text_embeds + assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))