Remove rank stabilization

rockerBOO
2025-03-24 04:22:12 -04:00
parent 85928dd3b0
commit 58bdf85ab4

@@ -66,10 +66,7 @@ class LoRAModule(torch.nn.Module):
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        rank_factor = self.lora_dim
-        if rank_stabilized:
-            rank_factor = math.sqrt(rank_factor)
-        self.scale = alpha / rank_factor
+        self.scale = alpha / self.lora_dim
         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
         self.split_dims = split_dims
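
For context, the removed branch implemented rank-stabilized LoRA (rsLoRA) scaling, which divides alpha by the square root of the rank instead of the rank itself; this commit reverts to the standard LoRA scaling alpha / rank. Below is a minimal sketch of the two formulas, assuming only the names visible in the diff (alpha, lora_dim, rank_stabilized); the helper function itself is hypothetical and not part of the repository.

import math

def lora_scale(alpha: float, lora_dim: int, rank_stabilized: bool = False) -> float:
    # Standard LoRA:          scale = alpha / rank
    # Rank-stabilized (rsLoRA): scale = alpha / sqrt(rank),
    # which keeps the update magnitude from shrinking as the rank grows.
    rank_factor = math.sqrt(lora_dim) if rank_stabilized else lora_dim
    return alpha / rank_factor

# Example with alpha=16, lora_dim=16:
#   standard        -> 16 / 16       = 1.0
#   rank-stabilized -> 16 / sqrt(16) = 4.0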