feat: Add option to split projection layers and apply LoRA

Kohya S
2024-08-24 16:35:43 +09:00
parent 2e89cd2cc6
commit cf689e7aa6
4 changed files with 323 additions and 66 deletions

View File

@@ -9,6 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
 The command to install PyTorch is as follows:
 `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

+Aug 24, 2024 (update 2):
+
+__Experimental__ Added an option to split the q/k/v/txt projection layers in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `split_qkv=True` in `network_args`, for example `--network_args "split_qkv=True"` (`train_blocks` is also available).
+
+The number of parameters increases slightly, so expressiveness may improve, but training time may also increase. No detailed verification has been done.
+
+This implementation is experimental, so it may be deprecated or changed in the future.
+
+The `.safetensors` file of the trained model is compatible with normal sd-scripts LoRA models, so it should be usable as-is in inference environments such as ComfyUI. Converting it to the AI-toolkit (Diffusers) format with `convert_flux_lora.py` will also reduce its size; if the model is only used for inference, converting it should cause no problems.
+
+Technical details: in the Black Forest Labs model implementation, the q/k/v (and txt in the single blocks) projection layers are concatenated into one. If LoRA is applied to that concatenated layer as-is, there is only one LoRA module and its dimension is large. In the Diffusers implementation, by contrast, the q/k/v/txt projection layers are separate, so LoRA is applied to each of them individually and each module's dimension is smaller. This option trains LoRA in the latter style (see the sketch below this diff).
+
+Compatibility of the saved model (state dict) is ensured by concatenating the weights of the multiple LoRA modules back into the combined layout. However, since parts of the concatenated weight are zero, the file size becomes larger.
+
 Aug 24, 2024:
 Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified.
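
To make the "combined vs. split" distinction above concrete, here is a minimal standalone sketch. It is not code from this commit; the rank of 4 is an arbitrary example, and the shapes follow the FLUX.1 double blocks. It shows one LoRA pair over the fused qkv projection versus one small pair per projection, plus the block-diagonal layout used when the split weights are saved back in the combined format:

import torch

in_dim = 3072            # double-block hidden size in FLUX.1
split_dims = [3072] * 3  # q, k, v output widths
rank = 4                 # example rank, chosen arbitrarily

# combined: one LoRA pair over the fused qkv projection (out_dim = 9216)
combined_down = torch.nn.Linear(in_dim, rank, bias=False)
combined_up = torch.nn.Linear(rank, sum(split_dims), bias=False)

# split_qkv: one small LoRA pair per projection
split_downs = [torch.nn.Linear(in_dim, rank, bias=False) for _ in split_dims]
split_ups = [torch.nn.Linear(rank, d, bias=False) for d in split_dims]

# saving in the combined layout: the down weights are concatenated, and each up
# weight occupies one diagonal block; the off-diagonal blocks stay zero
down_weight = torch.cat([m.weight.detach() for m in split_downs], dim=0)  # (rank*3, in_dim)
up_weight = torch.zeros(sum(split_dims), rank * len(split_dims))          # (9216, 12), mostly zeros
row = 0
for i, (m, d) in enumerate(zip(split_ups, split_dims)):
    up_weight[row : row + d, i * rank : (i + 1) * rank] = m.weight.detach()
    row += d

For the single-block `linear1` layer the same idea applies with `split_dims = [3072] * 3 + [12288]`, as used in the network code changes further below.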

View File

@@ -18,7 +18,7 @@ def main(file):
     keys = list(sd.keys())
     for key in keys:
-        if "lora_up" in key or "lora_down" in key:
+        if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key:
            values.append((key, sd[key]))

    print(f"number of LoRA modules: {len(values)}")

View File

@@ -266,11 +266,12 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     if sds_key + ".lora_down.weight" not in sds_sd:
         return
     down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+    up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+    sd_lora_rank = down_weight.shape[0]

     # scale weight by alpha and dim
-    rank = down_weight.shape[0]
     alpha = sds_sd.pop(sds_key + ".alpha")
-    scale = alpha / rank
+    scale = alpha / sd_lora_rank

     # calculate scale_down and scale_up
     scale_down = scale
@@ -279,23 +280,49 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         scale_down *= 2
         scale_up /= 2

-    ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
-    ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
-
-    num_splits = len(ait_keys)
-    up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
-
-    # down_weight is copied to each split
-    ait_sd.update({k: down_weight * scale_down for k in ait_down_keys})
+    down_weight = down_weight * scale_down
+    up_weight = up_weight * scale_up

     # calculate dims if not provided
+    num_splits = len(ait_keys)
     if dims is None:
         dims = [up_weight.shape[0] // num_splits] * num_splits
     else:
         assert sum(dims) == up_weight.shape[0]

-    # up_weight is split to each split
-    ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
+    # check whether up_weight is sparse or not
+    is_sparse = False
+    if sd_lora_rank % num_splits == 0:
+        ait_rank = sd_lora_rank // num_splits
+        is_sparse = True
+        i = 0
+        for j in range(len(dims)):
+            for k in range(len(dims)):
+                if j == k:
+                    continue
+                is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0)
+            i += dims[j]
+        if is_sparse:
+            logger.info(f"weight is sparse: {sds_key}")
+
+    # make ai-toolkit weight
+    ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+    ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+
+    if not is_sparse:
+        # down_weight is copied to each split
+        ait_sd.update({k: down_weight for k in ait_down_keys})
+
+        # up_weight is split to each split
+        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
+    else:
+        # down_weight is chunked to each split
+        ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})
+
+        # up_weight is sparse: only non-zero values are copied to each split
+        i = 0
+        for j in range(len(dims)):
+            ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+            i += dims[j]


 def convert_sd_scripts_to_ai_toolkit(sds_sd):
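
To see what the sparsity check above is looking for, here is a small standalone example with hypothetical sizes (a combined rank-12 double-block qkv LoRA converted into three AI-toolkit keys). It mirrors the slicing in `convert_to_ai_toolkit_cat` but is not part of the converter:

import torch

dims = [3072, 3072, 3072]  # ait output widths for q, k, v
sd_lora_rank = 12          # rank of the combined sd-scripts LoRA (example)
num_splits = len(dims)
ait_rank = sd_lora_rank // num_splits  # 4: rank of each split LoRA

# a combined up weight written by the split-qkv network's state_dict():
# each split occupies one diagonal block, everything else is zero
up_weight = torch.zeros(sum(dims), sd_lora_rank)
for j in range(num_splits):
    up_weight[j * 3072 : (j + 1) * 3072, j * ait_rank : (j + 1) * ait_rank] = torch.randn(3072, ait_rank)

# the converter detects this layout by checking that all off-diagonal blocks are zero
is_sparse = all(
    torch.all(up_weight[j * 3072 : (j + 1) * 3072, k * ait_rank : (k + 1) * ait_rank] == 0)
    for j in range(num_splits)
    for k in range(num_splits)
    if j != k
)
print(is_sparse)  # True

# when sparse, each ait key keeps only its non-zero block instead of the full width
q_up = up_weight[0:3072, 0:ait_rank].contiguous()  # (3072, 4) instead of (3072, 12)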

View File

@@ -39,6 +39,7 @@ class LoRAModule(torch.nn.Module):
         dropout=None,
         rank_dropout=None,
         module_dropout=None,
+        split_dims: Optional[List[int]] = None,
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
@@ -52,16 +53,34 @@ class LoRAModule(torch.nn.Module):
             out_dim = org_module.out_features

         self.lora_dim = lora_dim
+        self.split_dims = split_dims

-        if org_module.__class__.__name__ == "Conv2d":
-            kernel_size = org_module.kernel_size
-            stride = org_module.stride
-            padding = org_module.padding
-            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
-            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+        if split_dims is None:
+            if org_module.__class__.__name__ == "Conv2d":
+                kernel_size = org_module.kernel_size
+                stride = org_module.stride
+                padding = org_module.padding
+                self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+                self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+            else:
+                self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+                self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+            torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+            torch.nn.init.zeros_(self.lora_up.weight)
         else:
-            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
-            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+            # conv2d not supported
+            assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
+            assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
+            # print(f"split_dims: {split_dims}")
+            self.lora_down = torch.nn.ModuleList(
+                [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
+            )
+            self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
+            for lora_down in self.lora_down:
+                torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
+            for lora_up in self.lora_up:
+                torch.nn.init.zeros_(lora_up.weight)

         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
@@ -70,9 +89,6 @@ class LoRAModule(torch.nn.Module):
         self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

         # same as microsoft's
-        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_up.weight)
-
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
         self.dropout = dropout
@@ -92,30 +108,56 @@ class LoRAModule(torch.nn.Module):
             if torch.rand(1) < self.module_dropout:
                 return org_forwarded

+        if self.split_dims is None:
             lx = self.lora_down(x)

             # normal dropout
             if self.dropout is not None and self.training:
                 lx = torch.nn.functional.dropout(lx, p=self.dropout)

             # rank dropout
             if self.rank_dropout is not None and self.training:
                 mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
                 if len(lx.size()) == 3:
                     mask = mask.unsqueeze(1)  # for Text Encoder
                 elif len(lx.size()) == 4:
                     mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
                 lx = lx * mask

                 # scaling for rank dropout: treat as if the rank is changed
                 # this could be computed from the mask, but rank_dropout is used here, expecting an augmentation-like effect
                 scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
             else:
                 scale = self.scale

             lx = self.lora_up(lx)

             return org_forwarded + lx * self.multiplier * scale
+        else:
+            lxs = [lora_down(x) for lora_down in self.lora_down]
+
+            # normal dropout
+            if self.dropout is not None and self.training:
+                lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
+
+            # rank dropout
+            if self.rank_dropout is not None and self.training:
+                masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
+                for i in range(len(lxs)):
+                    if len(lx.size()) == 3:
+                        masks[i] = masks[i].unsqueeze(1)
+                    elif len(lx.size()) == 4:
+                        masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
+                    lxs[i] = lxs[i] * masks[i]
+
+                # scaling for rank dropout: treat as if the rank is changed
+                scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+            else:
+                scale = self.scale
+
+            lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+
+            return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale


 class LoRAInfModule(LoRAModule):
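
The split branch above returns `torch.cat(lxs, dim=-1)`, which matches what a single combined LoRA would produce when its up weight holds the per-split up weights on the block diagonal. A standalone sanity check of that equivalence (arbitrary example sizes, plain tensors rather than the module classes above):

import torch

torch.manual_seed(0)
in_dim, rank, split_dims = 3072, 4, [3072, 3072, 3072]

x = torch.randn(2, in_dim)
downs = [torch.randn(rank, in_dim) for _ in split_dims]
ups = [torch.randn(d, rank) for d in split_dims]

# split form: each projection gets its own small LoRA, outputs are concatenated
split_out = torch.cat([(x @ d.T) @ u.T for d, u in zip(downs, ups)], dim=-1)

# combined form: stack the down weights, place the up weights block-diagonally
down = torch.cat(downs, dim=0)                             # (rank*3, in_dim)
up = torch.zeros(sum(split_dims), rank * len(split_dims))  # mostly zeros
for j, (u, d_out) in enumerate(zip(ups, split_dims)):
    up[j * d_out : (j + 1) * d_out, j * rank : (j + 1) * rank] = u
combined_out = (x @ down.T) @ up.T

print(torch.allclose(split_out, combined_out, atol=1e-4))  # True

This equivalence is why the merged state dict stays loadable by tools that expect a single LoRA module over the fused qkv projection.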
@@ -152,31 +194,50 @@ class LoRAInfModule(LoRAModule):
         if device is None:
             device = org_device

+        if self.split_dims is None:
             # get up/down weight
             down_weight = sd["lora_down.weight"].to(torch.float).to(device)
             up_weight = sd["lora_up.weight"].to(torch.float).to(device)

             # merge weight
             if len(weight.size()) == 2:
                 # linear
                 weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
             elif down_weight.size()[2:4] == (1, 1):
                 # conv2d 1x1
                 weight = (
                     weight
                     + self.multiplier
                     * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                     * self.scale
                 )
             else:
                 # conv2d 3x3
                 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                 # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                 weight = weight + self.multiplier * conved * self.scale

             # set weight to org_module
             org_sd["weight"] = weight.to(dtype)
             self.org_module.load_state_dict(org_sd)
+        else:
+            # split_dims
+            total_dims = sum(self.split_dims)
+            for i in range(len(self.split_dims)):
+                # get up/down weight
+                down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device)  # (rank, in_dim)
+                up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device)  # (split dim, rank)
+
+                # pad up_weight -> (total_dims, rank)
+                padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
+                padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
+
+                # merge weight
+                weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+
+            # set weight to org_module
+            org_sd["weight"] = weight.to(dtype)
+            self.org_module.load_state_dict(org_sd)

     # return this module's weight so that the merge can be restored later
     def get_weight(self, multiplier=None):
@@ -211,7 +272,14 @@ class LoRAInfModule(LoRAModule):
     def default_forward(self, x):
         # logger.info(f"default_forward {self.lora_name} {x.size()}")
-        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        if self.split_dims is None:
+            lx = self.lora_down(x)
+            lx = self.lora_up(lx)
+            return self.org_forward(x) + lx * self.multiplier * self.scale
+        else:
+            lxs = [lora_down(x) for lora_down in self.lora_down]
+            lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+            return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale

     def forward(self, x):
         if not self.enabled:
@@ -257,6 +325,11 @@ def create_network(
     if train_blocks is not None:
         assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"

+    # split qkv
+    split_qkv = kwargs.get("split_qkv", False)
+    if split_qkv is not None:
+        split_qkv = True if split_qkv == "True" else False
+
     # that's a lot of arguments ( ^ω^)・・・
     network = LoRANetwork(
         text_encoders,
@@ -270,6 +343,7 @@ def create_network(
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
         train_blocks=train_blocks,
+        split_qkv=split_qkv,
         varbose=True,
     )
@@ -311,10 +385,34 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
             modules_dim[lora_name] = dim
             # logger.info(lora_name, value.size(), dim)

+    # # split qkv
+    # double_qkv_rank = None
+    # single_qkv_rank = None
+    # rank = None
+    # for lora_name, dim in modules_dim.items():
+    #     if "double" in lora_name and "qkv" in lora_name:
+    #         double_qkv_rank = dim
+    #     elif "single" in lora_name and "linear1" in lora_name:
+    #         single_qkv_rank = dim
+    #     elif rank is None:
+    #         rank = dim
+    #     if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
+    #         break
+    # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
+    #     single_qkv_rank is not None and single_qkv_rank != rank
+    # )
+    split_qkv = False  # split_qkv does not need to be detected here, because the state dict keeps qkv combined
+
     module_class = LoRAInfModule if for_inference else LoRAModule

     network = LoRANetwork(
-        text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+        text_encoders,
+        flux,
+        multiplier=multiplier,
+        modules_dim=modules_dim,
+        modules_alpha=modules_alpha,
+        module_class=module_class,
+        split_qkv=split_qkv,
     )
     return network, weights_sd
@@ -344,6 +442,7 @@ class LoRANetwork(torch.nn.Module):
         modules_dim: Optional[Dict[str, int]] = None,
         modules_alpha: Optional[Dict[str, int]] = None,
         train_blocks: Optional[str] = None,
+        split_qkv: bool = False,
         varbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
@@ -357,6 +456,7 @@ class LoRANetwork(torch.nn.Module):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
         self.train_blocks = train_blocks if train_blocks is not None else "all"
+        self.split_qkv = split_qkv

         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
@@ -373,6 +473,8 @@ class LoRANetwork(torch.nn.Module):
             logger.info(
                 f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
             )
+        if self.split_qkv:
+            logger.info(f"split qkv for LoRA")

         # create module instances
         def create_modules(
@@ -420,6 +522,14 @@ class LoRANetwork(torch.nn.Module):
                                    skipped.append(lora_name)
                                continue

+                            # qkv split
+                            split_dims = None
+                            if is_flux and split_qkv:
+                                if "double" in lora_name and "qkv" in lora_name:
+                                    split_dims = [3072] * 3
+                                elif "single" in lora_name and "linear1" in lora_name:
+                                    split_dims = [3072] * 3 + [12288]
+
                            lora = module_class(
                                lora_name,
                                child_module,
@@ -429,6 +539,7 @@ class LoRANetwork(torch.nn.Module):
                                dropout=dropout,
                                rank_dropout=rank_dropout,
                                module_dropout=module_dropout,
+                                split_dims=split_dims,
                            )
                            loras.append(lora)
            return loras, skipped
@@ -492,6 +603,111 @@ class LoRANetwork(torch.nn.Module):
         info = self.load_state_dict(weights_sd, False)
         return info

+    def load_state_dict(self, state_dict, strict=True):
+        # override to convert original (combined) weights to split qkv weights
+        if not self.split_qkv:
+            return super().load_state_dict(state_dict, strict)
+
+        # split qkv
+        for key in list(state_dict.keys()):
+            if "double" in key and "qkv" in key:
+                split_dims = [3072] * 3
+            elif "single" in key and "linear1" in key:
+                split_dims = [3072] * 3 + [12288]
+            else:
+                continue
+
+            weight = state_dict[key]
+            lora_name = key.split(".")[0]
+
+            if "lora_down" in key and "weight" in key:
+                # dense weight (rank*3, in_dim)
+                split_weight = torch.chunk(weight, len(split_dims), dim=0)
+                for i, split_w in enumerate(split_weight):
+                    state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w
+
+                del state_dict[key]
+                # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}")
+            elif "lora_up" in key and "weight" in key:
+                # sparse weight (out_dim=sum(split_dims), rank*3)
+                rank = weight.size(1) // len(split_dims)
+                i = 0
+                for j in range(len(split_dims)):
+                    state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank]
+                    i += split_dims[j]
+
+                del state_dict[key]
+
+                # # check whether the weight is sparse
+                # i = 0
+                # is_zero = True
+                # for j in range(len(split_dims)):
+                #     for k in range(len(split_dims)):
+                #         if j == k:
+                #             continue
+                #         is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0)
+                #     i += split_dims[j]
+                # if not is_zero:
+                #     logger.warning(f"weight is not sparse: {key}")
+                # else:
+                #     logger.info(f"weight is sparse: {key}")
+
+                # print(
+                #     f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}"
+                # )
+
+        # alpha is unchanged
+
+        return super().load_state_dict(state_dict, strict)
+
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
+        if not self.split_qkv:
+            return super().state_dict(destination, prefix, keep_vars)
+
+        # merge qkv
+        state_dict = super().state_dict(destination, prefix, keep_vars)
+        new_state_dict = {}
+        for key in list(state_dict.keys()):
+            if "double" in key and "qkv" in key:
+                split_dims = [3072] * 3
+            elif "single" in key and "linear1" in key:
+                split_dims = [3072] * 3 + [12288]
+            else:
+                new_state_dict[key] = state_dict[key]
+                continue
+
+            if key not in state_dict:
+                continue  # already merged
+
+            lora_name = key.split(".")[0]
+
+            # (rank, in_dim) * 3
+            down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
+            # (split dim, rank) * 3
+            up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
+
+            alpha = state_dict.pop(f"{lora_name}.alpha")
+
+            # merge down weight
+            down_weight = torch.cat(down_weights, dim=0)  # (rank, in_dim) * 3 -> (rank*3, in_dim)
+
+            # merge up weight (sum of split_dim, rank*3)
+            rank = up_weights[0].size(1)
+            up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
+            i = 0
+            for j in range(len(split_dims)):
+                up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
+                i += split_dims[j]
+
+            new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
+            new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
+            new_state_dict[f"{lora_name}.alpha"] = alpha
+
+            # print(
+            #     f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
+            # )
+            print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
+
+        return new_state_dict
+
     def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True):
         if apply_text_encoder:
             logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")