mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix split_qkv
This commit is contained in:
@@ -540,8 +540,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
||||||
|
|
||||||
# merge up weight (sum of split_dim, rank*3)
|
# merge up weight (sum of split_dim, rank*3)
|
||||||
qkv_dim, rank = up_weights[0].size()
|
split_dim, rank = up_weights[0].size()
|
||||||
split_dim = qkv_dim // 3
|
qkv_dim = split_dim * 3
|
||||||
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
|
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
|
||||||
i = 0
|
i = 0
|
||||||
for j in range(3):
|
for j in range(3):
|
||||||
|
|||||||
Reference in New Issue
Block a user