Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
feat: Add option to split projection layers and apply LoRA
@@ -18,7 +18,7 @@ def main(file):
     keys = list(sd.keys())
     for key in keys:
-        if "lora_up" in key or "lora_down" in key:
+        if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key:
             values.append((key, sd[key]))
     print(f"number of LoRA modules: {len(values)}")
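Note: the extra clauses make the checker recognize both key layouts: sd-scripts saves LoRA matrices as lora_down/lora_up (plus an alpha scalar), while ai-toolkit-style (PEFT) checkpoints name the same matrices lora_A/lora_B. A minimal sketch of the two layouts for one hypothetical module (names and shapes are illustrative, not from this commit):

import torch

sd_scripts_style = {
    "lora_unet_example.lora_down.weight": torch.randn(4, 8),  # (rank, in_dim)
    "lora_unet_example.lora_up.weight": torch.zeros(8, 4),    # (out_dim, rank)
    "lora_unet_example.alpha": torch.tensor(4.0),
}
peft_style = {
    "example.lora_A.weight": torch.randn(4, 8),  # corresponds to lora_down
    "example.lora_B.weight": torch.zeros(8, 4),  # corresponds to lora_up
}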
@@ -266,11 +266,12 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     if sds_key + ".lora_down.weight" not in sds_sd:
         return
     down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+    up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+    sd_lora_rank = down_weight.shape[0]
 
     # scale weight by alpha and dim
-    rank = down_weight.shape[0]
     alpha = sds_sd.pop(sds_key + ".alpha")
-    scale = alpha / rank
+    scale = alpha / sd_lora_rank
 
     # calculate scale_down and scale_up
     scale_down = scale
@@ -279,23 +280,49 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         scale_down *= 2
         scale_up /= 2
 
-    ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
-    ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
-
-    num_splits = len(ait_keys)
-    up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
-
-    # down_weight is copied to each split
-    ait_sd.update({k: down_weight * scale_down for k in ait_down_keys})
+    down_weight = down_weight * scale_down
+    up_weight = up_weight * scale_up
 
     # calculate dims if not provided
+    num_splits = len(ait_keys)
     if dims is None:
         dims = [up_weight.shape[0] // num_splits] * num_splits
     else:
         assert sum(dims) == up_weight.shape[0]
 
-    # up_weight is split to each split
-    ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
+    # check if up_weight is sparse or not
+    is_sparse = False
+    if sd_lora_rank % num_splits == 0:
+        ait_rank = sd_lora_rank // num_splits
+        is_sparse = True
+        i = 0
+        for j in range(len(dims)):
+            for k in range(len(dims)):
+                if j == k:
+                    continue
+                is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0)
+            i += dims[j]
+        if is_sparse:
+            logger.info(f"weight is sparse: {sds_key}")
+
+    # make ai-toolkit weight
+    ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+    ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+    if not is_sparse:
+        # down_weight is copied to each split
+        ait_sd.update({k: down_weight for k in ait_down_keys})
+
+        # up_weight is split to each split
+        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
+    else:
+        # down_weight is chunked to each split
+        ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})
+
+        # up_weight is sparse: only non-zero values are copied to each split
+        i = 0
+        for j in range(len(dims)):
+            ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+            i += dims[j]
 
 
 def convert_sd_scripts_to_ai_toolkit(sds_sd):
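The sparsity check above captures the idea behind the whole commit: when one combined LoRA module covers q, k and v (and, in single blocks, the MLP input), it can be split losslessly only if the combined up weight is block-diagonal, i.e. split j reads only the j-th rank slice. In that case each split gets its own (dims[j], ait_rank) block. A small self-contained sketch of that layout and the slicing, with shrunken illustrative shapes (not part of the commit):

import torch

# Three splits ("q", "k", "v"), each with out_dim 6 and per-split rank 2.
dims, ait_rank = [6, 6, 6], 2
up_blocks = [torch.randn(d, ait_rank) for d in dims]

# Combined up weight: (sum(dims), ait_rank * num_splits), zero off the block diagonal.
up_weight = torch.zeros(sum(dims), ait_rank * len(dims))
i = 0
for j, block in enumerate(up_blocks):
    up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank] = block
    i += dims[j]

# The converter's slice recovers each split's own up weight exactly.
i = 0
for j in range(len(dims)):
    recovered = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank]
    assert torch.equal(recovered, up_blocks[j])
    i += dims[j]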
@@ -39,6 +39,7 @@ class LoRAModule(torch.nn.Module):
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        split_dims: Optional[List[int]] = None,
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
@@ -52,16 +53,34 @@ class LoRAModule(torch.nn.Module):
             out_dim = org_module.out_features
 
         self.lora_dim = lora_dim
+        self.split_dims = split_dims
 
-        if org_module.__class__.__name__ == "Conv2d":
-            kernel_size = org_module.kernel_size
-            stride = org_module.stride
-            padding = org_module.padding
-            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
-            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
-        else:
-            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
-            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+        if split_dims is None:
+            if org_module.__class__.__name__ == "Conv2d":
+                kernel_size = org_module.kernel_size
+                stride = org_module.stride
+                padding = org_module.padding
+                self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+                self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+            else:
+                self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+                self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+            torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+            torch.nn.init.zeros_(self.lora_up.weight)
+        else:
+            # conv2d not supported
+            assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
+            assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
+            # print(f"split_dims: {split_dims}")
+            self.lora_down = torch.nn.ModuleList(
+                [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
+            )
+            self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
+            for lora_down in self.lora_down:
+                torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
+            for lora_up in self.lora_up:
+                torch.nn.init.zeros_(lora_up.weight)
 
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
@@ -70,9 +89,6 @@ class LoRAModule(torch.nn.Module):
         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
 
-        # same as microsoft's
-        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_up.weight)
-
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
         self.dropout = dropout
@@ -92,30 +108,56 @@ class LoRAModule(torch.nn.Module):
             if torch.rand(1) < self.module_dropout:
                 return org_forwarded
 
-        lx = self.lora_down(x)
-
-        # normal dropout
-        if self.dropout is not None and self.training:
-            lx = torch.nn.functional.dropout(lx, p=self.dropout)
-
-        # rank dropout
-        if self.rank_dropout is not None and self.training:
-            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
-            if len(lx.size()) == 3:
-                mask = mask.unsqueeze(1)  # for Text Encoder
-            elif len(lx.size()) == 4:
-                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
-            lx = lx * mask
-
-            # scaling for rank dropout: treat as if the rank is changed
-            # it could also be computed from the mask, but rank_dropout is used here expecting an augmentation-like effect
-            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
-        else:
-            scale = self.scale
-
-        lx = self.lora_up(lx)
-
-        return org_forwarded + lx * self.multiplier * scale
+        if self.split_dims is None:
+            lx = self.lora_down(x)
+
+            # normal dropout
+            if self.dropout is not None and self.training:
+                lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+            # rank dropout
+            if self.rank_dropout is not None and self.training:
+                mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+                if len(lx.size()) == 3:
+                    mask = mask.unsqueeze(1)  # for Text Encoder
+                elif len(lx.size()) == 4:
+                    mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+                lx = lx * mask
+
+                # scaling for rank dropout: treat as if the rank is changed
+                # it could also be computed from the mask, but rank_dropout is used here expecting an augmentation-like effect
+                scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+            else:
+                scale = self.scale
+
+            lx = self.lora_up(lx)
+
+            return org_forwarded + lx * self.multiplier * scale
+        else:
+            lxs = [lora_down(x) for lora_down in self.lora_down]
+
+            # normal dropout
+            if self.dropout is not None and self.training:
+                lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
+
+            # rank dropout
+            if self.rank_dropout is not None and self.training:
+                masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
+                for i in range(len(lxs)):
+                    if len(lxs[i].size()) == 3:
+                        masks[i] = masks[i].unsqueeze(1)
+                    elif len(lxs[i].size()) == 4:
+                        masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
+                    lxs[i] = lxs[i] * masks[i]
+
+                # scaling for rank dropout: treat as if the rank is changed
+                scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+            else:
+                scale = self.scale
+
+            lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+
+            return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
 
 
 class LoRAInfModule(LoRAModule):
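In the split branch each lora_up produces only its own slice of the output, and the slices are concatenated on the last dimension. This is mathematically a single up projection with a block-diagonal weight, which is why the combined and split representations are interconvertible (see the state_dict handling further down). A standalone sketch of the equivalence, with illustrative shapes:

import torch

in_dim, rank, split_dims = 16, 4, [8, 8, 8]
x = torch.randn(2, in_dim)

lora_down = [torch.nn.Linear(in_dim, rank, bias=False) for _ in split_dims]
lora_up = [torch.nn.Linear(rank, d, bias=False) for d in split_dims]

# Per-split path, as in LoRAModule.forward when split_dims is set.
lxs = [up(down(x)) for up, down in zip(lora_up, lora_down)]
split_out = torch.cat(lxs, dim=-1)  # (2, sum(split_dims))

# The same computation as one Linear pair with a block-diagonal up weight.
big_down = torch.cat([down.weight for down in lora_down], dim=0)  # (rank*3, in_dim)
big_up = torch.block_diag(*[up.weight for up in lora_up])  # (sum(split_dims), rank*3)
single_out = x @ big_down.T @ big_up.T

assert torch.allclose(split_out, single_out, atol=1e-5)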
@@ -152,31 +194,50 @@ class LoRAInfModule(LoRAModule):
         if device is None:
             device = org_device
 
-        # get up/down weight
-        up_weight = sd["lora_up.weight"].to(torch.float).to(device)
-        down_weight = sd["lora_down.weight"].to(torch.float).to(device)
-
-        # merge weight
-        if len(weight.size()) == 2:
-            # linear
-            weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
-        elif down_weight.size()[2:4] == (1, 1):
-            # conv2d 1x1
-            weight = (
-                weight
-                + self.multiplier
-                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
-                * self.scale
-            )
-        else:
-            # conv2d 3x3
-            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
-            # logger.info(conved.size(), weight.size(), module.stride, module.padding)
-            weight = weight + self.multiplier * conved * self.scale
-
-        # set weight to org_module
-        org_sd["weight"] = weight.to(dtype)
-        self.org_module.load_state_dict(org_sd)
+        if self.split_dims is None:
+            # get up/down weight
+            down_weight = sd["lora_down.weight"].to(torch.float).to(device)
+            up_weight = sd["lora_up.weight"].to(torch.float).to(device)
+
+            # merge weight
+            if len(weight.size()) == 2:
+                # linear
+                weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+            elif down_weight.size()[2:4] == (1, 1):
+                # conv2d 1x1
+                weight = (
+                    weight
+                    + self.multiplier
+                    * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                    * self.scale
+                )
+            else:
+                # conv2d 3x3
+                conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+                # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+                weight = weight + self.multiplier * conved * self.scale
+
+            # set weight to org_module
+            org_sd["weight"] = weight.to(dtype)
+            self.org_module.load_state_dict(org_sd)
+        else:
+            # split_dims
+            total_dims = sum(self.split_dims)
+            for i in range(len(self.split_dims)):
+                # get up/down weight
+                down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device)  # (rank, in_dim)
+                up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device)  # (split dim, rank)
+
+                # pad up_weight -> (total_dims, rank)
+                padded_up_weight = torch.zeros((total_dims, up_weight.size(1)), device=device, dtype=torch.float)
+                padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
+
+                # merge weight
+                weight = weight + self.multiplier * (padded_up_weight @ down_weight) * self.scale
+
+            # set weight to org_module
+            org_sd["weight"] = weight.to(dtype)
+            self.org_module.load_state_dict(org_sd)
 
     # return this module's weight so that the merge can be undone later
     def get_weight(self, multiplier=None):
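When merging split weights into the base layer, each split's delta only touches its own rows of the output, so the (split_dim, rank) up weight is zero-padded to (total_dims, rank) before the usual up @ down product. A sketch of the padding-based merge under assumed small shapes (illustrative, not from the commit):

import torch

in_dim, rank, split_dims = 16, 4, [8, 8]
total_dims = sum(split_dims)
weight = torch.randn(total_dims, in_dim)  # stand-in for the base Linear weight

downs = [torch.randn(rank, in_dim) for _ in split_dims]
ups = [torch.randn(d, rank) for d in split_dims]

merged = weight.clone()
for i in range(len(split_dims)):
    # pad the (split_dim, rank) up weight to (total_dims, rank), as merge_to does
    padded_up = torch.zeros(total_dims, rank)
    padded_up[sum(split_dims[:i]) : sum(split_dims[: i + 1])] = ups[i]
    merged = merged + padded_up @ downs[i]  # multiplier and scale omitted

# Identical to adding each split's delta into its own rows only.
expected = weight.clone()
row = 0
for d, down_w, up_w in zip(split_dims, downs, ups):
    expected[row : row + d] += up_w @ down_w
    row += d
assert torch.allclose(merged, expected, atol=1e-5)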
@@ -211,7 +272,14 @@ class LoRAInfModule(LoRAModule):
 
     def default_forward(self, x):
         # logger.info(f"default_forward {self.lora_name} {x.size()}")
-        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        if self.split_dims is None:
+            lx = self.lora_down(x)
+            lx = self.lora_up(lx)
+            return self.org_forward(x) + lx * self.multiplier * self.scale
+        else:
+            lxs = [lora_down(x) for lora_down in self.lora_down]
+            lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+            return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
 
     def forward(self, x):
         if not self.enabled:
@@ -257,6 +325,11 @@ def create_network(
     if train_blocks is not None:
         assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
 
+    # split qkv
+    split_qkv = kwargs.get("split_qkv", False)
+    if split_qkv is not None:
+        split_qkv = True if split_qkv == "True" else False
+
     # there are really a lot of arguments ( ^ω^)...
     network = LoRANetwork(
         text_encoders,
@@ -270,6 +343,7 @@ def create_network(
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
         train_blocks=train_blocks,
+        split_qkv=split_qkv,
         varbose=True,
     )
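Assuming this follows the usual sd-scripts pattern of passing network options through --network_args (the kwargs.get("split_qkv", ...) above parses a string), training with split projections would look roughly like the following; the script name and surrounding flags are placeholders, not taken from this commit:

accelerate launch flux_train_network.py \
    --network_module networks.lora_flux \
    --network_dim 16 \
    --network_args "split_qkv=True" \
    ...

Note the comparison split_qkv == "True": the value must arrive as the literal string True, which is how --network_args key=value pairs are delivered.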
@@ -311,10 +385,34 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
         modules_dim[lora_name] = dim
         # logger.info(lora_name, value.size(), dim)
 
+    # # split qkv
+    # double_qkv_rank = None
+    # single_qkv_rank = None
+    # rank = None
+    # for lora_name, dim in modules_dim.items():
+    #     if "double" in lora_name and "qkv" in lora_name:
+    #         double_qkv_rank = dim
+    #     elif "single" in lora_name and "linear1" in lora_name:
+    #         single_qkv_rank = dim
+    #     elif rank is None:
+    #         rank = dim
+    #     if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
+    #         break
+    # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
+    #     single_qkv_rank is not None and single_qkv_rank != rank
+    # )
+    split_qkv = False  # no need to handle split_qkv here, because the saved state_dict has qkv combined
+
     module_class = LoRAInfModule if for_inference else LoRAModule
 
     network = LoRANetwork(
-        text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+        text_encoders,
+        flux,
+        multiplier=multiplier,
+        modules_dim=modules_dim,
+        modules_alpha=modules_alpha,
+        module_class=module_class,
+        split_qkv=split_qkv,
     )
     return network, weights_sd
@@ -344,6 +442,7 @@ class LoRANetwork(torch.nn.Module):
         modules_dim: Optional[Dict[str, int]] = None,
         modules_alpha: Optional[Dict[str, int]] = None,
         train_blocks: Optional[str] = None,
+        split_qkv: bool = False,
         varbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
@@ -357,6 +456,7 @@ class LoRANetwork(torch.nn.Module):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
         self.train_blocks = train_blocks if train_blocks is not None else "all"
+        self.split_qkv = split_qkv
 
         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
@@ -373,6 +473,8 @@ class LoRANetwork(torch.nn.Module):
             logger.info(
                 f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
             )
+        if self.split_qkv:
+            logger.info(f"split qkv for LoRA")
 
         # create module instances
         def create_modules(
@@ -420,6 +522,14 @@ class LoRANetwork(torch.nn.Module):
                             skipped.append(lora_name)
                             continue
 
+                        # qkv split
+                        split_dims = None
+                        if is_flux and split_qkv:
+                            if "double" in lora_name and "qkv" in lora_name:
+                                split_dims = [3072] * 3
+                            elif "single" in lora_name and "linear1" in lora_name:
+                                split_dims = [3072] * 3 + [12288]
+
                         lora = module_class(
                             lora_name,
                             child_module,
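The hard-coded split sizes follow Flux's hidden width of 3072: a double block's qkv projection stacks q, k and v (3 x 3072), and a single block's linear1 additionally stacks the MLP input, whose width is 4 x 3072 = 12288. A quick consistency check under those assumptions:

hidden = 3072
mlp_ratio = 4

double_qkv = [hidden] * 3                             # q, k, v
single_linear1 = [hidden] * 3 + [hidden * mlp_ratio]  # q, k, v, mlp

assert sum(double_qkv) == 9216
assert sum(single_linear1) == 21504  # out_features of a single block's linear1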
@@ -429,6 +539,7 @@ class LoRANetwork(torch.nn.Module):
                             dropout=dropout,
                             rank_dropout=rank_dropout,
                             module_dropout=module_dropout,
+                            split_dims=split_dims,
                         )
                         loras.append(lora)
             return loras, skipped
@@ -492,6 +603,111 @@ class LoRANetwork(torch.nn.Module):
         info = self.load_state_dict(weights_sd, False)
         return info
 
+    def load_state_dict(self, state_dict, strict=True):
+        # override to convert original weight to split qkv weight
+        if not self.split_qkv:
+            return super().load_state_dict(state_dict, strict)
+
+        # split qkv
+        for key in list(state_dict.keys()):
+            if "double" in key and "qkv" in key:
+                split_dims = [3072] * 3
+            elif "single" in key and "linear1" in key:
+                split_dims = [3072] * 3 + [12288]
+            else:
+                continue
+
+            weight = state_dict[key]
+            lora_name = key.split(".")[0]
+            if "lora_down" in key and "weight" in key:
+                # dense weight (rank*3, in_dim)
+                split_weight = torch.chunk(weight, len(split_dims), dim=0)
+                for i, split_w in enumerate(split_weight):
+                    state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w
+
+                del state_dict[key]
+                # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}")
+            elif "lora_up" in key and "weight" in key:
+                # sparse weight (out_dim=sum(split_dims), rank*3)
+                rank = weight.size(1) // len(split_dims)
+                i = 0
+                for j in range(len(split_dims)):
+                    state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank]
+                    i += split_dims[j]
+                del state_dict[key]
+
+                # # check is sparse
+                # i = 0
+                # is_zero = True
+                # for j in range(len(split_dims)):
+                #     for k in range(len(split_dims)):
+                #         if j == k:
+                #             continue
+                #         is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0)
+                #     i += split_dims[j]
+                # if not is_zero:
+                #     logger.warning(f"weight is not sparse: {key}")
+                # else:
+                #     logger.info(f"weight is sparse: {key}")
+
+                # print(
+                #     f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}"
+                # )
+
+            # alpha is unchanged
+
+        return super().load_state_dict(state_dict, strict)
+
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
+        if not self.split_qkv:
+            return super().state_dict(destination, prefix, keep_vars)
+
+        # merge qkv
+        state_dict = super().state_dict(destination, prefix, keep_vars)
+        new_state_dict = {}
+        for key in list(state_dict.keys()):
+            if "double" in key and "qkv" in key:
+                split_dims = [3072] * 3
+            elif "single" in key and "linear1" in key:
+                split_dims = [3072] * 3 + [12288]
+            else:
+                new_state_dict[key] = state_dict[key]
+                continue
+
+            if key not in state_dict:
+                continue  # already merged
+
+            lora_name = key.split(".")[0]
+
+            # (rank, in_dim) * 3
+            down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
+            # (split dim, rank) * 3
+            up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
+
+            alpha = state_dict.pop(f"{lora_name}.alpha")
+
+            # merge down weight
+            down_weight = torch.cat(down_weights, dim=0)  # (rank, in_dim) * 3 -> (rank*3, in_dim)
+
+            # merge up weight (sum of split_dim, rank*3)
+            rank = up_weights[0].size(1)
+            up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
+            i = 0
+            for j in range(len(split_dims)):
+                up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
+                i += split_dims[j]
+
+            new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
+            new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
+            new_state_dict[f"{lora_name}.alpha"] = alpha
+
+            # print(
+            #     f"merged {lora_name}: {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
+            # )
+            print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
+
+        return new_state_dict
+
     def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True):
         if apply_text_encoder:
             logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
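Taken together, the overridden state_dict() always saves combined qkv tensors (a dense concatenated down weight and a block-diagonal up weight), and load_state_dict() splits them back into per-split entries, so checkpoints stay compatible with consumers that expect combined modules and the round trip is lossless. A compressed sketch of both directions for one module (shapes shrunk; illustrative, not from the commit):

import torch

split_dims, rank, in_dim = [4, 4, 4], 2, 8  # per-split rank

# Per-split weights as held by a split LoRAModule.
downs = [torch.randn(rank, in_dim) for _ in split_dims]
ups = [torch.randn(d, rank) for d in split_dims]

# state_dict() direction: merge into combined tensors.
down_weight = torch.cat(downs, dim=0)  # (rank*3, in_dim)
up_weight = torch.zeros(sum(split_dims), down_weight.size(0))
i = 0
for j in range(len(split_dims)):
    up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = ups[j]
    i += split_dims[j]

# load_state_dict() direction: split back out.
downs2 = list(torch.chunk(down_weight, len(split_dims), dim=0))
ups2, i = [], 0
for j in range(len(split_dims)):
    ups2.append(up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank])
    i += split_dims[j]

assert all(torch.equal(a, b) for a, b in zip(downs, downs2))
assert all(torch.equal(a, b) for a, b in zip(ups, ups2))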