fix: improve pos_embed handling for oversized images and update resolution_area_to_latent_size when the sample image size is larger than the training image size

This commit is contained in:
Kohya S
2024-11-30 18:25:50 +09:00
parent 2a61fc0784
commit 9c885e549d

View File

@@ -1017,22 +1017,35 @@ class MMDiT(nn.Module):
patched_size = patched_size_
break
if patched_size is None:
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
# raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
# use largest latent size
patched_size = self.resolution_area_to_latent_size[-1][1]
pos_embed = self.resolution_pos_embeds[patched_size]
pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
pos_embed_size = round(math.sqrt(pos_embed.shape[1])) # max size, patched_size * POS_EMBED_MAX_RATIO
if h > pos_embed_size or w > pos_embed_size:
# # fallback to normal pos_embed
# return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
# extend pos_embed size
logger.warning(
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
)
pos_embed_size = max(h, w)
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
patched_size = max(h, w)
grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
pos_embed_size = grid_size
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
self.resolution_pos_embeds[patched_size] = pos_embed
logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
logger.info(f"Added pos_embed for size {patched_size}x{patched_size}")
# print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2))
# diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu()
# print(diff.abs().max(), diff.abs().mean())
# insert to resolution_area_to_latent_size, by adding and sorting
area = pos_embed_size**2
self.resolution_area_to_latent_size.append((area, patched_size))
self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size)
if not random_crop:
top = (pos_embed_size - h) // 2