fix: keep latents 4D except DiT call

This commit is contained in:
Kohya S
2026-02-10 21:26:20 +09:00
parent 6a4e392445
commit 02a75944b3
3 changed files with 26 additions and 24 deletions

View File

@@ -41,16 +41,11 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
if (args.fp8_base or args.fp8_base_unet) and not args.fp8_scaled:
logger.warning(
"fp8_base and fp8_base_unet are not supported. Use fp8_scaled instead / fp8_baseとfp8_base_unetはサポートされていません。代わりにfp8_scaledを使用してください"
)
if args.fp8_scaled and (args.fp8_base or args.fp8_base_unet):
logger.info(
"fp8_scaled is used, so fp8_base and fp8_base_unet are ignored / fp8_scaledが使われているので、fp8_baseとfp8_base_unetは無視されます"
)
if args.fp8_base or args.fp8_base_unet:
logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
args.fp8_base = False
args.fp8_base_unet = False
args.fp8_scaled = False # Anima DiT does not support fp8_scaled
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
@@ -249,7 +244,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def encode_images_to_latents(self, args, vae, images):
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
return vae.encode_pixels_to_latents(images)
return vae.encode_pixels_to_latents(images) # Keep 4D for input/output
def shift_scale_latents(self, args, latents):
# Latents already normalized by vae.encode with scale
@@ -272,6 +267,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
anima: anima_models.Anima = unet
# Sample noise
if latents.ndim == 5: # Fallback for 5D latents (old cache)
latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
noise = torch.randn_like(latents)
# Get noisy model input and timesteps
@@ -302,11 +299,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
w_latent = latents.shape[-1]
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
# Prepare block swap
if self.is_swapping_blocks:
accelerator.unwrap_model(anima).prepare_block_swap_before_forward()
# Call model
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
with torch.set_grad_enabled(is_train), accelerator.autocast():
model_pred = anima(
noisy_model_input,
@@ -317,6 +311,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
)
model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]
# Rectified flow target: noise - latents
target = noise - latents
@@ -344,10 +339,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
"""Override base process_batch for caption dropout with cached text encoder outputs.
Base class now supports 4D and 5D latents, so we only need to handle caption dropout here.
"""
"""Override base process_batch for caption dropout with cached text encoder outputs."""
# Text encoder conditions
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
@@ -418,6 +410,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
@@ -425,7 +418,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
anima_train_utils.add_anima_training_arguments(parser)
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
# parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument(
"--unsloth_offload_checkpointing",
action="store_true",

View File

@@ -1008,14 +1008,19 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
return {"sample": decoded}
def decode_to_pixels(self, latents: torch.Tensor) -> torch.Tensor:
vae_scale_factor = 2 ** len(self.temperal_downsample)
# latents = qwen_image_utils.unpack_latents(latent, height, width, vae_scale_factor)
is_4d = latents.dim() == 4
if is_4d:
latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
latents = latents.to(self.dtype)
latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
image = self.decode(latents, return_dict=False)[0][:, :, 0] # -1 to 1
# return (image * 0.5 + 0.5).clamp(0.0, 1.0) # Convert to [0, 1] range
image = self.decode(latents, return_dict=False)[0] # -1 to 1
if is_4d:
image = image.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
return image.clamp(-1.0, 1.0)
def encode_pixels_to_latents(self, pixels: torch.Tensor) -> torch.Tensor:
@@ -1032,7 +1037,8 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
# pixels = (pixels * 2.0 - 1.0).clamp(-1.0, 1.0)
# Handle 2D input by adding temporal dimension
if pixels.dim() == 4:
is_4d = pixels.dim() == 4
if is_4d:
pixels = pixels.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
pixels = pixels.to(self.dtype)
@@ -1047,6 +1053,9 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * latents_std
if is_4d:
latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
return latents
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:

View File

@@ -291,7 +291,7 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
Returns latents with the same rank as the input, on CPU: (B, 16, H/8, W/8) for 4D input, (B, 16, 1, H/8, W/8) for 5D input.
"""
latents = vae.encode_pixels_to_latents(img_tensor)
latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output
return latents.to("cpu")
self._default_cache_batch_latents(