From bf3a13bb4e4c4d45f2bedd3fbb752f33b3ec907b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 24 Jan 2023 18:57:21 +0900 Subject: [PATCH] Fix error when loading bf16 weights --- networks/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora.py b/networks/lora.py index 9243f1e1..b936bfb2 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -31,7 +31,7 @@ class LoRAModule(torch.nn.Module): self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: - alpha = alpha.detach().numpy() + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える