From 609d1292f6e262b27a8c5b2849e7bf0df2ecd7a8 Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 23 Feb 2026 13:13:40 +0700
Subject: [PATCH] Fix the LoRA dropout issue in the Anima model during LoRA
 training. (#2272)

* Support network_reg_alphas and fix bug when setting rank_dropout in training lora for anima model

* Update anima_train_network.md

* Update anima_train_network.md

* Remove network_reg_alphas

* Update document
---
 docs/anima_train_network.md |  2 +-
 networks/lora_anima.py      |  2 +-
 networks/lora_flux.py       | 13 ++++++++-----
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/docs/anima_train_network.md b/docs/anima_train_network.md
index f97aa975..5d67ae36 100644
--- a/docs/anima_train_network.md
+++ b/docs/anima_train_network.md
@@ -652,4 +652,4 @@ The following metadata is saved in the LoRA model file:
 * `ss_sigmoid_scale`
 * `ss_discrete_flow_shift`
 
-
+
\ No newline at end of file
diff --git a/networks/lora_anima.py b/networks/lora_anima.py
index 224ef20c..9413e8c8 100644
--- a/networks/lora_anima.py
+++ b/networks/lora_anima.py
@@ -636,4 +636,4 @@ class LoRANetwork(torch.nn.Module):
             scalednorm = updown.norm() * ratio
             norms.append(scalednorm.item())
 
-        return keys_scaled, sum(norms) / len(norms), max(norms)
+        return keys_scaled, sum(norms) / len(norms), max(norms)
\ No newline at end of file
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index d74d0172..947733fe 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -141,10 +141,13 @@ class LoRAModule(torch.nn.Module):
         # rank dropout
         if self.rank_dropout is not None and self.training:
             mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
-            if len(lx.size()) == 3:
-                mask = mask.unsqueeze(1)  # for Text Encoder
-            elif len(lx.size()) == 4:
-                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            if isinstance(self.lora_down, torch.nn.Conv2d):
+                # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
+                mask = mask.unsqueeze(-1).unsqueeze(-1)
+            else:
+                # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
+                for _ in range(len(lx.size()) - 2):
+                    mask = mask.unsqueeze(1)
             lx = lx * mask
 
             # scaling for rank dropout: treat as if the rank is changed
@@ -1445,4 +1448,4 @@ class LoRANetwork(torch.nn.Module):
             scalednorm = updown.norm() * ratio
             norms.append(scalednorm.item())
 
-        return keys_scaled, sum(norms) / len(norms), max(norms)
+        return keys_scaled, sum(norms) / len(norms), max(norms)
\ No newline at end of file
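
Note (not part of the patch): the lora_flux.py hunk changes the rank-dropout mask from being shaped by the activation's number of dimensions to being shaped by the type of the down-projection module. The following is a minimal, standalone sketch of that broadcasting behavior; the rank_dropout_mask helper, the shapes, and the example tensors are illustrative assumptions for this note, not code from the repository.

import torch

def rank_dropout_mask(lx: torch.Tensor, lora_dim: int, p: float, is_conv: bool) -> torch.Tensor:
    # Hypothetical helper: keep each rank with probability (1 - p), sampled per batch element.
    mask = torch.rand((lx.size(0), lora_dim), device=lx.device) > p
    if is_conv:
        # Conv2d down-projection: lx is [B, lora_dim, H, W], so the mask becomes
        # [B, lora_dim, 1, 1] and broadcasts over the spatial dimensions.
        mask = mask.unsqueeze(-1).unsqueeze(-1)
    else:
        # Linear down-projection: lora_dim sits on the last axis, so insert singleton
        # dimensions in the middle, e.g. [B, T, lora_dim] -> mask [B, 1, lora_dim].
        for _ in range(lx.dim() - 2):
            mask = mask.unsqueeze(1)
    return mask

if __name__ == "__main__":
    # Linear case with a 4D activation: this is the shape the old ndim-based branch
    # treated as Conv2d and therefore masked along the wrong axis.
    lx_linear = torch.randn(2, 8, 16, 4)            # [B, ..., lora_dim]
    m = rank_dropout_mask(lx_linear, 4, 0.5, is_conv=False)
    print((lx_linear * m).shape)                    # torch.Size([2, 8, 16, 4])

    # Conv2d case: lora_dim is the channel dimension.
    lx_conv = torch.randn(2, 4, 32, 32)             # [B, lora_dim, H, W]
    m = rank_dropout_mask(lx_conv, 4, 0.5, is_conv=True)
    print((lx_conv * m).shape)                      # torch.Size([2, 4, 32, 32])

Checking the module type rather than the tensor rank means a Linear LoRA that happens to receive a 4D activation (as in the Anima model) still drops whole ranks along the last axis instead of masking per spatial position.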