This commit is contained in:
sdbds
2025-02-15 16:38:59 +08:00
parent d154e76c45
commit c0caf33e3f
2 changed files with 171 additions and 12 deletions

View File

@@ -108,14 +108,6 @@ def load_gemma2(
logger.info(f"Loaded Gemma2: {info}")
return gemma2
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2