diff --git a/networks/lora_anima.py b/networks/lora_anima.py index 85d30aac..4cff2819 100644 --- a/networks/lora_anima.py +++ b/networks/lora_anima.py @@ -93,10 +93,13 @@ class LoRAModule(torch.nn.Module): # rank dropout if self.rank_dropout is not None and self.training: mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + if isinstance(self.lora_down, torch.nn.Conv2d): + # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1] + mask = mask.unsqueeze(-1).unsqueeze(-1) + else: + # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim] + for _ in range(len(lx.size()) - 2): + mask = mask.unsqueeze(1) lx = lx * mask # scaling for rank dropout: treat as if the rank is changed