Fix error for loading bf16 weights

This commit is contained in:
Kohya S
2023-01-24 18:57:21 +09:00
parent 93df55d597
commit bf3a13bb4e


@@ -31,7 +31,7 @@ class LoRAModule(torch.nn.Module):
       self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
 
     if type(alpha) == torch.Tensor:
-      alpha = alpha.detach().numpy()
+      alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
     alpha = lora_dim if alpha is None or alpha == 0 else alpha
     self.scale = alpha / self.lora_dim
     self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant
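
For context, a minimal sketch (illustration only, not part of the commit) of the failure this cast avoids: NumPy has no bfloat16 dtype, so calling `.numpy()` directly on a bf16 `alpha` tensor raises a `TypeError`, while casting to float32 first succeeds.

```python
import torch

# NumPy has no bfloat16 dtype, so .numpy() on a bf16 tensor raises
# "TypeError: Got unsupported ScalarType BFloat16".
alpha = torch.tensor(4.0, dtype=torch.bfloat16)

try:
    alpha.detach().numpy()  # fails when the checkpoint stores alpha in bf16
except TypeError as err:
    print(err)

# Casting to float32 first gives NumPy a dtype it understands.
value = alpha.detach().float().numpy()
print(value)  # 4.0
```

The extra `.float()` is a no-op for fp16/fp32 checkpoints beyond an upcast, so the same code path now handles bf16 weights as well.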