diff --git a/library/hunyuan_image_vae.py b/library/hunyuan_image_vae.py index 6f6eea22..b66854e5 100644 --- a/library/hunyuan_image_vae.py +++ b/library/hunyuan_image_vae.py @@ -449,7 +449,7 @@ class HunyuanVAE2D(nn.Module): """ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: @@ -467,7 +467,7 @@ class HunyuanVAE2D(nn.Module): """ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) return b def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: @@ -478,9 +478,14 @@ class HunyuanVAE2D(nn.Module): Parameters ---------- x : torch.Tensor - Input tensor of shape (B, C, T, H, W). + Input tensor of shape (B, C, T, H, W) or (B, C, H, W). """ - B, C, T, H, W = x.shape + # Handle 5D input (B, C, T, H, W) by removing time dimension + original_ndim = x.ndim + if original_ndim == 5: + x = x.squeeze(2) + + B, C, H, W = x.shape overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_extent @@ -489,7 +494,7 @@ class HunyuanVAE2D(nn.Module): for i in range(0, H, overlap_size): row = [] for j in range(0, W, overlap_size): - tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] tile = self.encoder(tile) row.append(tile) rows.append(row) @@ -502,7 +507,7 @@ class HunyuanVAE2D(nn.Module): tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=-1)) moments = torch.cat(result_rows, dim=-2)