From 2c2e2e02d74e841f89d94f644456739bdd44dcd4 Mon Sep 17 00:00:00 2001
From: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Date: Mon, 23 Feb 2026 15:15:59 +0900
Subject: [PATCH] fix: rank dropout handling in LoRAModule for Conv2d and
 Linear layers, see #2272 for details

---
 networks/lora_anima.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/networks/lora_anima.py b/networks/lora_anima.py
index 85d30aac..4cff2819 100644
--- a/networks/lora_anima.py
+++ b/networks/lora_anima.py
@@ -93,10 +93,13 @@ class LoRAModule(torch.nn.Module):
         # rank dropout
         if self.rank_dropout is not None and self.training:
             mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
-            if len(lx.size()) == 3:
-                mask = mask.unsqueeze(1)  # for Text Encoder
-            elif len(lx.size()) == 4:
-                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            if isinstance(self.lora_down, torch.nn.Conv2d):
+                # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
+                mask = mask.unsqueeze(-1).unsqueeze(-1)
+            else:
+                # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
+                for _ in range(len(lx.size()) - 2):
+                    mask = mask.unsqueeze(1)
             lx = lx * mask

             # scaling for rank dropout: treat as if the rank is changed
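
A minimal standalone sketch (not part of the patch) of why the two branches
reshape the mask differently: a Conv2d lora_down puts lora_dim on dim 1 of
its output, while a Linear lora_down puts it on the last dim, so the keep
mask must gain its singleton dims in different places to broadcast. The
helper name rank_dropout_mask and the concrete shapes below are illustrative
assumptions, not code from the repository.

    import torch

    def rank_dropout_mask(lx: torch.Tensor, lora_down: torch.nn.Module,
                          lora_dim: int, rank_dropout: float) -> torch.Tensor:
        # Per-sample, per-rank keep mask: True where the rank channel survives.
        mask = torch.rand((lx.size(0), lora_dim), device=lx.device) > rank_dropout
        if isinstance(lora_down, torch.nn.Conv2d):
            # Conv2d output [B, lora_dim, H, W] -> mask [B, lora_dim, 1, 1]
            mask = mask.unsqueeze(-1).unsqueeze(-1)
        else:
            # Linear output [B, ..., lora_dim] -> mask [B, 1, ..., 1, lora_dim]
            for _ in range(lx.dim() - 2):
                mask = mask.unsqueeze(1)
        return mask

    B, dim = 4, 8

    # Conv2d case: lx is [B, dim, H, W], mask broadcasts over H and W.
    conv = torch.nn.Conv2d(3, dim, kernel_size=1)
    lx = conv(torch.randn(B, 3, 16, 16))
    assert rank_dropout_mask(lx, conv, dim, 0.5).shape == (B, dim, 1, 1)

    # Linear case with a sequence dim: lx is [B, S, dim], mask broadcasts over S.
    lin = torch.nn.Linear(32, dim)
    lx = lin(torch.randn(B, 77, 32))
    assert rank_dropout_mask(lx, lin, dim, 0.5).shape == (B, 1, dim)

Keying the branch off the layer type rather than the tensor rank, as the patch
does, avoids misclassifying inputs whose dimensionality alone is ambiguous
(e.g. a Linear applied to a tensor with extra leading dims).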