From 6d08c93b239d29002822be9e9aa953d075154ac2 Mon Sep 17 00:00:00 2001
From: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Date: Tue, 10 Feb 2026 21:26:50 +0900
Subject: [PATCH] feat: enhance block swap functionality for inference and training in Anima model

---
Review notes (text between the "---" line and the first "diff --git" line is
not part of the commit message and is ignored when the patch is applied):

- Attention.forward: q and k are cast to v's dtype only when q.dtype differs
  from v.dtype, autocast is enabled, and attn_params reports either no fp32
  support or a same-dtype requirement.
- Anima.switch_block_swap_for_inference / switch_block_swap_for_training:
  both return early when blocks_to_swap is None or 0; otherwise they call
  offloader.set_forward_only(True / False) and then
  prepare_block_swap_before_forward(). The status messages are plain string
  literals because they contain no placeholders.
- sample_images: switches the model to forward-only block swap after the
  text encoder is unwrapped and back to training mode after the CUDA RNG
  state is restored.
  NOTE(review): the code between these two hunks is not visible here; if
  sampling can raise, the switch back to training mode would be skipped -
  confirm whether a try/finally is needed.
- load_anima_model: the relaxed assert now fails only when fp8_scaled is
  truthy and dit_weight_dtype is not None.

 library/anima_models.py      | 20 ++++++++++++++++++++
 library/anima_train_utils.py |  5 ++++-
 library/anima_utils.py       |  4 +++-
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/library/anima_models.py b/library/anima_models.py
index aabd14d1..d8faabf2 100644
--- a/library/anima_models.py
+++ b/library/anima_models.py
@@ -401,6 +401,12 @@ class Attention(nn.Module):
         rope_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
+        if q.dtype != v.dtype:
+            if (not attn_params.supports_fp32 or attn_params.requires_same_dtype) and torch.is_autocast_enabled():
+                # FlashAttention requires fp16/bf16, xformers require same dtype; only cast when autocast is active.
+                target_dtype = v.dtype  # v has fp16/bf16 dtype
+                q = q.to(target_dtype)
+                k = k.to(target_dtype)
         # return self.compute_attention(q, k, v)
         qkv = [q, k, v]
         del q, k, v
@@ -1304,6 +1310,20 @@ class Anima(nn.Module):
         if self.blocks_to_swap:
             self.blocks = save_blocks
 
+    def switch_block_swap_for_inference(self):
+        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+            return
+        self.offloader.set_forward_only(True)
+        self.prepare_block_swap_before_forward()
+        print("Anima: Block swap set to forward only.")
+
+    def switch_block_swap_for_training(self):
+        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+            return
+        self.offloader.set_forward_only(False)
+        self.prepare_block_swap_before_forward()
+        print("Anima: Block swap set to forward and backward.")
+
     def prepare_block_swap_before_forward(self):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             return
diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py
index feae032e..29b75188 100644
--- a/library/anima_train_utils.py
+++ b/library/anima_train_utils.py
@@ -444,7 +444,7 @@ def sample_images(
     args: argparse.Namespace,
     epoch,
     steps,
-    dit,
+    dit: anima_models.Anima,
     vae,
     text_encoder,
     tokenize_strategy,
@@ -479,6 +479,8 @@ def sample_images(
     if text_encoder is not None:
         text_encoder = accelerator.unwrap_model(text_encoder)
 
+    dit.switch_block_swap_for_inference()
+
     prompts = train_util.load_prompts(args.sample_prompts)
     save_dir = os.path.join(args.output_dir, "sample")
     os.makedirs(save_dir, exist_ok=True)
@@ -514,6 +516,7 @@ def sample_images(
     if cuda_rng_state is not None:
         torch.cuda.set_rng_state(cuda_rng_state)
 
+    dit.switch_block_swap_for_training()
     clean_memory_on_device(accelerator.device)
 
 
diff --git a/library/anima_utils.py b/library/anima_utils.py
index 213188e5..d5bb24df 100644
--- a/library/anima_utils.py
+++ b/library/anima_utils.py
@@ -154,7 +154,9 @@ def load_anima_model(
         lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
     """
     # dit_weight_dtype is None for fp8_scaled
-    assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
+    assert (
+        not fp8_scaled and dit_weight_dtype is not None
+    ) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
 
     device = torch.device(device)
     loading_device = torch.device(loading_device)