feat: enhance block swap functionality for inference and training in Anima model

This commit is contained in:
Kohya S
2026-02-10 21:26:50 +09:00
parent 02a75944b3
commit 6d08c93b23
3 changed files with 27 additions and 2 deletions

View File

@@ -401,6 +401,12 @@ class Attention(nn.Module):
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
if q.dtype != v.dtype:
if (not attn_params.supports_fp32 or attn_params.requires_same_dtype) and torch.is_autocast_enabled():
# FlashAttention requires fp16/bf16; xformers requires q/k/v to share a dtype. Only cast when autocast is active.
target_dtype = v.dtype # v has fp16/bf16 dtype
q = q.to(target_dtype)
k = k.to(target_dtype)
# return self.compute_attention(q, k, v)
qkv = [q, k, v]
del q, k, v
@@ -1304,6 +1310,20 @@ class Anima(nn.Module):
if self.blocks_to_swap:
self.blocks = save_blocks
def switch_block_swap_for_inference(self):
    """Switch the block offloader to forward-only mode for inference.

    Inference has no backward pass, so blocks only need to be staged in
    the forward direction. No-op when block swapping is disabled
    (``self.blocks_to_swap`` is None or 0).
    """
    if not self.blocks_to_swap:  # None or 0 -> block swapping disabled
        return
    self.offloader.set_forward_only(True)
    # Re-stage blocks so the first forward pass starts from a clean state.
    self.prepare_block_swap_before_forward()
    print("Anima: Block swap set to forward only.")
def switch_block_swap_for_training(self):
    """Switch the block offloader back to forward-and-backward mode for training.

    Training needs blocks staged for both the forward and backward passes.
    No-op when block swapping is disabled (``self.blocks_to_swap`` is None
    or 0).
    """
    if not self.blocks_to_swap:  # None or 0 -> block swapping disabled
        return
    self.offloader.set_forward_only(False)
    # Re-stage blocks so the next training step starts from a clean state.
    self.prepare_block_swap_before_forward()
    print("Anima: Block swap set to forward and backward.")
def prepare_block_swap_before_forward(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return

View File

@@ -444,7 +444,7 @@ def sample_images(
args: argparse.Namespace,
epoch,
steps,
dit,
dit: anima_models.Anima,
vae,
text_encoder,
tokenize_strategy,
@@ -479,6 +479,8 @@ def sample_images(
if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)
dit.switch_block_swap_for_inference()
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = os.path.join(args.output_dir, "sample")
os.makedirs(save_dir, exist_ok=True)
@@ -514,6 +516,7 @@ def sample_images(
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
dit.switch_block_swap_for_training()
clean_memory_on_device(accelerator.device)

View File

@@ -154,7 +154,9 @@ def load_anima_model(
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
# dit_weight_dtype is None for fp8_scaled
assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
assert (
not fp8_scaled and dit_weight_dtype is not None
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
device = torch.device(device)
loading_device = torch.device(loading_device)