mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 16:22:28 +00:00
fix: rank dropout handling in LoRAModule for Conv2d and Linear layers, see #2272 for details
This commit is contained in:
@@ -93,10 +93,13 @@ class LoRAModule(torch.nn.Module):
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.training:
|
||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||
if len(lx.size()) == 3:
|
||||
mask = mask.unsqueeze(1) # for Text Encoder
|
||||
elif len(lx.size()) == 4:
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
||||
if isinstance(self.lora_down, torch.nn.Conv2d):
|
||||
# Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1)
|
||||
else:
|
||||
# Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
|
||||
for _ in range(len(lx.size()) - 2):
|
||||
mask = mask.unsqueeze(1)
|
||||
lx = lx * mask
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
|
||||
Reference in New Issue
Block a user