mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
fix: improve pos_embed handling for oversized images and update resolution_area_to_latent_size, when sample image size > train image size
This commit is contained in:
@@ -1017,22 +1017,35 @@ class MMDiT(nn.Module):
|
||||
patched_size = patched_size_
|
||||
break
|
||||
if patched_size is None:
|
||||
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
|
||||
# raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
|
||||
# use largest latent size
|
||||
patched_size = self.resolution_area_to_latent_size[-1][1]
|
||||
|
||||
pos_embed = self.resolution_pos_embeds[patched_size]
|
||||
pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
|
||||
pos_embed_size = round(math.sqrt(pos_embed.shape[1])) # max size, patched_size * POS_EMBED_MAX_RATIO
|
||||
if h > pos_embed_size or w > pos_embed_size:
|
||||
# # fallback to normal pos_embed
|
||||
# return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
|
||||
# extend pos_embed size
|
||||
logger.warning(
|
||||
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
|
||||
f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
|
||||
)
|
||||
pos_embed_size = max(h, w)
|
||||
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
|
||||
patched_size = max(h, w)
|
||||
grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
|
||||
pos_embed_size = grid_size
|
||||
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
|
||||
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
|
||||
self.resolution_pos_embeds[patched_size] = pos_embed
|
||||
logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
|
||||
logger.info(f"Added pos_embed for size {patched_size}x{patched_size}")
|
||||
|
||||
# print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2))
|
||||
# diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu()
|
||||
# print(diff.abs().max(), diff.abs().mean())
|
||||
|
||||
# insert to resolution_area_to_latent_size, by adding and sorting
|
||||
area = pos_embed_size**2
|
||||
self.resolution_area_to_latent_size.append((area, patched_size))
|
||||
self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size)
|
||||
|
||||
if not random_crop:
|
||||
top = (pos_embed_size - h) // 2
|
||||
|
||||
Reference in New Issue
Block a user