diff --git a/anima_train_network.py b/anima_train_network.py index 812fda7d..ad4c771c 100644 --- a/anima_train_network.py +++ b/anima_train_network.py @@ -41,16 +41,11 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup], ): - if (args.fp8_base or args.fp8_base_unet) and not args.fp8_scaled: - logger.warning( - "fp8_base and fp8_base_unet are not supported. Use fp8_scaled instead / fp8_baseとfp8_base_unetはサポートされていません。代わりにfp8_scaledを使用してください" - ) - if args.fp8_scaled and (args.fp8_base or args.fp8_base_unet): - logger.info( - "fp8_scaled is used, so fp8_base and fp8_base_unet are ignored / fp8_scaledが使われているので、fp8_baseとfp8_base_unetは無視されます" - ) + if args.fp8_base or args.fp8_base_unet: + logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。") args.fp8_base = False args.fp8_base_unet = False + args.fp8_scaled = False # Anima DiT does not support fp8_scaled if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled") @@ -249,7 +244,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): def encode_images_to_latents(self, args, vae, images): vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage - return vae.encode_pixels_to_latents(images) + return vae.encode_pixels_to_latents(images) # Keep 4D for input/output def shift_scale_latents(self, args, latents): # Latents already normalized by vae.encode with scale @@ -272,6 +267,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): anima: anima_models.Anima = unet # Sample noise + if latents.ndim == 5: # Fallback for 5D latents (old cache) + latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W] noise = torch.randn_like(latents) # Get noisy model input and timesteps @@ -302,11 +299,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): w_latent = latents.shape[-1] padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device) - # Prepare block swap - if self.is_swapping_blocks: - accelerator.unwrap_model(anima).prepare_block_swap_before_forward() - # Call model + noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W] with torch.set_grad_enabled(is_train), accelerator.autocast(): model_pred = anima( noisy_model_input, @@ -317,6 +311,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): target_attention_mask=t5_attn_mask, source_attention_mask=attn_mask, ) + model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W] # Rectified flow target: noise - latents target = noise - latents @@ -344,10 +339,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): train_text_encoder=True, train_unet=True, ) -> torch.Tensor: - """Override base process_batch for caption dropout with cached text encoder outputs. - - Base class now supports 4D and 5D latents, so we only need to handle caption dropout here. - """ + """Override base process_batch for caption dropout with cached text encoder outputs.""" # Text encoder conditions text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) @@ -418,6 +410,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here accelerator.unwrap_model(unet).prepare_block_swap_before_forward() @@ -425,7 +418,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() train_util.add_dit_training_arguments(parser) anima_train_utils.add_anima_training_arguments(parser) - parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う") + # parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う") parser.add_argument( "--unsloth_offload_checkpointing", action="store_true", diff --git a/library/qwen_image_autoencoder_kl.py b/library/qwen_image_autoencoder_kl.py index 61fc7550..2d1ce692 100644 --- a/library/qwen_image_autoencoder_kl.py +++ b/library/qwen_image_autoencoder_kl.py @@ -1008,14 +1008,19 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina return {"sample": decoded} def decode_to_pixels(self, latents: torch.Tensor) -> torch.Tensor: - vae_scale_factor = 2 ** len(self.temperal_downsample) - # latents = qwen_image_utils.unpack_latents(latent, height, width, vae_scale_factor) + is_4d = latents.dim() == 4 + if is_4d: + latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W] + latents = latents.to(self.dtype) latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean - image = self.decode(latents, return_dict=False)[0][:, :, 0] # -1 to 1 - # return (image * 0.5 + 0.5).clamp(0.0, 1.0) # Convert to [0, 1] range + + image = self.decode(latents, return_dict=False)[0] # -1 to 1 + if is_4d: + image = image.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W] + return image.clamp(-1.0, 1.0) def encode_pixels_to_latents(self, pixels: torch.Tensor) -> torch.Tensor: @@ -1032,7 +1037,8 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina # pixels = (pixels * 2.0 - 1.0).clamp(-1.0, 1.0) # Handle 2D input by adding temporal dimension - if pixels.dim() == 4: + is_4d = pixels.dim() == 4 + if is_4d: pixels = pixels.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W] pixels = pixels.to(self.dtype) @@ -1047,6 +1053,9 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents = (latents - latents_mean) * latents_std + if is_4d: + latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W] + return latents def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: diff --git a/library/strategy_anima.py b/library/strategy_anima.py index 3d225691..143159b2 100644 --- a/library/strategy_anima.py +++ b/library/strategy_anima.py @@ -291,7 +291,7 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy): Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape. Returns latents in (B, 16, 1, H/8, W/8) shape on CPU. """ - latents = vae.encode_pixels_to_latents(img_tensor) + latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output return latents.to("cpu") self._default_cache_batch_latents(