mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
feat: Add option to split projection layers and apply LoRA
This commit is contained in:
14
README.md
14
README.md
@@ -9,6 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
|||||||
The command to install PyTorch is as follows:
|
The command to install PyTorch is as follows:
|
||||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||||
|
|
||||||
|
Aug 24, 2024 (update 2):
|
||||||
|
|
||||||
|
__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available).
|
||||||
|
|
||||||
|
The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done.
|
||||||
|
|
||||||
|
This implementation is experimental, so it may be deprecated or changed in the future.
|
||||||
|
|
||||||
|
The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable as-is in inference environments such as ComfyUI. Converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` also reduces the file size, and the conversion should cause no problems if the model is only used for inference.
|
||||||
|
|
||||||
|
Technical details: In Black Forest Labs' implementation of the model, the projection layers of q/k/v (and txt in single blocks) are concatenated into a single layer. If LoRA is applied to that layer as-is, there is only one LoRA module and its dimension is large. In the Diffusers implementation, by contrast, the projection layers of q/k/v/txt are separate, so a LoRA module is applied to each of q/k/v/txt and each one has a smaller dimension. This option trains LoRA in a way similar to the latter.
|
||||||
|
|
||||||
|
Compatibility of the saved model (state dict) is ensured by concatenating the weights of the multiple LoRA modules. However, because parts of the concatenated weights are zero, the saved model is larger.
|
||||||
|
|
||||||
Aug 24, 2024:
|
Aug 24, 2024:
|
||||||
Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified.
|
Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified.
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ def main(file):
|
|||||||
|
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if "lora_up" in key or "lora_down" in key:
|
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key:
|
||||||
values.append((key, sd[key]))
|
values.append((key, sd[key]))
|
||||||
print(f"number of LoRA modules: {len(values)}")
|
print(f"number of LoRA modules: {len(values)}")
|
||||||
|
|
||||||
|
|||||||
@@ -266,11 +266,12 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
|||||||
if sds_key + ".lora_down.weight" not in sds_sd:
|
if sds_key + ".lora_down.weight" not in sds_sd:
|
||||||
return
|
return
|
||||||
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
||||||
|
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
|
||||||
|
sd_lora_rank = down_weight.shape[0]
|
||||||
|
|
||||||
# scale weight by alpha and dim
|
# scale weight by alpha and dim
|
||||||
rank = down_weight.shape[0]
|
|
||||||
alpha = sds_sd.pop(sds_key + ".alpha")
|
alpha = sds_sd.pop(sds_key + ".alpha")
|
||||||
scale = alpha / rank
|
scale = alpha / sd_lora_rank
|
||||||
|
|
||||||
# calculate scale_down and scale_up
|
# calculate scale_down and scale_up
|
||||||
scale_down = scale
|
scale_down = scale
|
||||||
@@ -279,23 +280,49 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
|||||||
scale_down *= 2
|
scale_down *= 2
|
||||||
scale_up /= 2
|
scale_up /= 2
|
||||||
|
|
||||||
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
|
down_weight = down_weight * scale_down
|
||||||
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
up_weight = up_weight * scale_up
|
||||||
|
|
||||||
num_splits = len(ait_keys)
|
|
||||||
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
|
|
||||||
|
|
||||||
# down_weight is copied to each split
|
|
||||||
ait_sd.update({k: down_weight * scale_down for k in ait_down_keys})
|
|
||||||
|
|
||||||
# calculate dims if not provided
|
# calculate dims if not provided
|
||||||
|
num_splits = len(ait_keys)
|
||||||
if dims is None:
|
if dims is None:
|
||||||
dims = [up_weight.shape[0] // num_splits] * num_splits
|
dims = [up_weight.shape[0] // num_splits] * num_splits
|
||||||
else:
|
else:
|
||||||
assert sum(dims) == up_weight.shape[0]
|
assert sum(dims) == up_weight.shape[0]
|
||||||
|
|
||||||
# up_weight is split to each split
|
# check upweight is sparse or not
|
||||||
ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
|
is_sparse = False
|
||||||
|
if sd_lora_rank % num_splits == 0:
|
||||||
|
ait_rank = sd_lora_rank // num_splits
|
||||||
|
is_sparse = True
|
||||||
|
i = 0
|
||||||
|
for j in range(len(dims)):
|
||||||
|
for k in range(len(dims)):
|
||||||
|
if j == k:
|
||||||
|
continue
|
||||||
|
is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0)
|
||||||
|
i += dims[j]
|
||||||
|
if is_sparse:
|
||||||
|
logger.info(f"weight is sparse: {sds_key}")
|
||||||
|
|
||||||
|
# make ai-toolkit weight
|
||||||
|
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
|
||||||
|
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
||||||
|
if not is_sparse:
|
||||||
|
# down_weight is copied to each split
|
||||||
|
ait_sd.update({k: down_weight for k in ait_down_keys})
|
||||||
|
|
||||||
|
# up_weight is split to each split
|
||||||
|
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
|
||||||
|
else:
|
||||||
|
# down_weight is chunked to each split
|
||||||
|
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})
|
||||||
|
|
||||||
|
# up_weight is sparse: only non-zero values are copied to each split
|
||||||
|
i = 0
|
||||||
|
for j in range(len(dims)):
|
||||||
|
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
|
||||||
|
i += dims[j]
|
||||||
|
|
||||||
|
|
||||||
def convert_sd_scripts_to_ai_toolkit(sds_sd):
|
def convert_sd_scripts_to_ai_toolkit(sds_sd):
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class LoRAModule(torch.nn.Module):
|
|||||||
dropout=None,
|
dropout=None,
|
||||||
rank_dropout=None,
|
rank_dropout=None,
|
||||||
module_dropout=None,
|
module_dropout=None,
|
||||||
|
split_dims: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -52,16 +53,34 @@ class LoRAModule(torch.nn.Module):
|
|||||||
out_dim = org_module.out_features
|
out_dim = org_module.out_features
|
||||||
|
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
|
self.split_dims = split_dims
|
||||||
|
|
||||||
if org_module.__class__.__name__ == "Conv2d":
|
if split_dims is None:
|
||||||
kernel_size = org_module.kernel_size
|
if org_module.__class__.__name__ == "Conv2d":
|
||||||
stride = org_module.stride
|
kernel_size = org_module.kernel_size
|
||||||
padding = org_module.padding
|
stride = org_module.stride
|
||||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
padding = org_module.padding
|
||||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||||
|
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||||
|
else:
|
||||||
|
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||||
|
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||||
|
|
||||||
|
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||||
|
torch.nn.init.zeros_(self.lora_up.weight)
|
||||||
else:
|
else:
|
||||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
# conv2d not supported
|
||||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
|
||||||
|
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
|
||||||
|
# print(f"split_dims: {split_dims}")
|
||||||
|
self.lora_down = torch.nn.ModuleList(
|
||||||
|
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
|
||||||
|
)
|
||||||
|
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
|
||||||
|
for lora_down in self.lora_down:
|
||||||
|
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
|
||||||
|
for lora_up in self.lora_up:
|
||||||
|
torch.nn.init.zeros_(lora_up.weight)
|
||||||
|
|
||||||
if type(alpha) == torch.Tensor:
|
if type(alpha) == torch.Tensor:
|
||||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||||
@@ -70,9 +89,6 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||||
|
|
||||||
# same as microsoft's
|
# same as microsoft's
|
||||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
|
||||||
torch.nn.init.zeros_(self.lora_up.weight)
|
|
||||||
|
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.org_module = org_module # remove in applying
|
self.org_module = org_module # remove in applying
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
@@ -92,30 +108,56 @@ class LoRAModule(torch.nn.Module):
|
|||||||
if torch.rand(1) < self.module_dropout:
|
if torch.rand(1) < self.module_dropout:
|
||||||
return org_forwarded
|
return org_forwarded
|
||||||
|
|
||||||
lx = self.lora_down(x)
|
if self.split_dims is None:
|
||||||
|
lx = self.lora_down(x)
|
||||||
|
|
||||||
# normal dropout
|
# normal dropout
|
||||||
if self.dropout is not None and self.training:
|
if self.dropout is not None and self.training:
|
||||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||||
|
|
||||||
# rank dropout
|
# rank dropout
|
||||||
if self.rank_dropout is not None and self.training:
|
if self.rank_dropout is not None and self.training:
|
||||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||||
if len(lx.size()) == 3:
|
if len(lx.size()) == 3:
|
||||||
mask = mask.unsqueeze(1) # for Text Encoder
|
mask = mask.unsqueeze(1) # for Text Encoder
|
||||||
elif len(lx.size()) == 4:
|
elif len(lx.size()) == 4:
|
||||||
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
||||||
lx = lx * mask
|
lx = lx * mask
|
||||||
|
|
||||||
# scaling for rank dropout: treat as if the rank is changed
|
# scaling for rank dropout: treat as if the rank is changed
|
||||||
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
||||||
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
||||||
|
else:
|
||||||
|
scale = self.scale
|
||||||
|
|
||||||
|
lx = self.lora_up(lx)
|
||||||
|
|
||||||
|
return org_forwarded + lx * self.multiplier * scale
|
||||||
else:
|
else:
|
||||||
scale = self.scale
|
lxs = [lora_down(x) for lora_down in self.lora_down]
|
||||||
|
|
||||||
lx = self.lora_up(lx)
|
# normal dropout
|
||||||
|
if self.dropout is not None and self.training:
|
||||||
|
lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
|
||||||
|
|
||||||
return org_forwarded + lx * self.multiplier * scale
|
# rank dropout
|
||||||
|
if self.rank_dropout is not None and self.training:
|
||||||
|
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
|
||||||
|
for i in range(len(lxs)):
|
||||||
|
if len(lx.size()) == 3:
|
||||||
|
masks[i] = masks[i].unsqueeze(1)
|
||||||
|
elif len(lx.size()) == 4:
|
||||||
|
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
|
||||||
|
lxs[i] = lxs[i] * masks[i]
|
||||||
|
|
||||||
|
# scaling for rank dropout: treat as if the rank is changed
|
||||||
|
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
||||||
|
else:
|
||||||
|
scale = self.scale
|
||||||
|
|
||||||
|
lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
|
||||||
|
|
||||||
|
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
|
||||||
|
|
||||||
|
|
||||||
class LoRAInfModule(LoRAModule):
|
class LoRAInfModule(LoRAModule):
|
||||||
@@ -152,31 +194,50 @@ class LoRAInfModule(LoRAModule):
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = org_device
|
device = org_device
|
||||||
|
|
||||||
# get up/down weight
|
if self.split_dims is None:
|
||||||
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
# get up/down weight
|
||||||
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
||||||
|
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
||||||
|
|
||||||
# merge weight
|
# merge weight
|
||||||
if len(weight.size()) == 2:
|
if len(weight.size()) == 2:
|
||||||
# linear
|
# linear
|
||||||
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
||||||
elif down_weight.size()[2:4] == (1, 1):
|
elif down_weight.size()[2:4] == (1, 1):
|
||||||
# conv2d 1x1
|
# conv2d 1x1
|
||||||
weight = (
|
weight = (
|
||||||
weight
|
weight
|
||||||
+ self.multiplier
|
+ self.multiplier
|
||||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
* self.scale
|
* self.scale
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# conv2d 3x3
|
||||||
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
|
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||||
|
weight = weight + self.multiplier * conved * self.scale
|
||||||
|
|
||||||
|
# set weight to org_module
|
||||||
|
org_sd["weight"] = weight.to(dtype)
|
||||||
|
self.org_module.load_state_dict(org_sd)
|
||||||
else:
|
else:
|
||||||
# conv2d 3x3
|
# split_dims
|
||||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
total_dims = sum(self.split_dims)
|
||||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
for i in range(len(self.split_dims)):
|
||||||
weight = weight + self.multiplier * conved * self.scale
|
# get up/down weight
|
||||||
|
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
|
||||||
|
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
|
||||||
|
|
||||||
# set weight to org_module
|
# pad up_weight -> (total_dims, rank)
|
||||||
org_sd["weight"] = weight.to(dtype)
|
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
|
||||||
self.org_module.load_state_dict(org_sd)
|
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
|
||||||
|
|
||||||
|
# merge weight
|
||||||
|
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
||||||
|
|
||||||
|
# set weight to org_module
|
||||||
|
org_sd["weight"] = weight.to(dtype)
|
||||||
|
self.org_module.load_state_dict(org_sd)
|
||||||
|
|
||||||
# 復元できるマージのため、このモジュールのweightを返す
|
# 復元できるマージのため、このモジュールのweightを返す
|
||||||
def get_weight(self, multiplier=None):
|
def get_weight(self, multiplier=None):
|
||||||
@@ -211,7 +272,14 @@ class LoRAInfModule(LoRAModule):
|
|||||||
|
|
||||||
def default_forward(self, x):
|
def default_forward(self, x):
|
||||||
# logger.info(f"default_forward {self.lora_name} {x.size()}")
|
# logger.info(f"default_forward {self.lora_name} {x.size()}")
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
if self.split_dims is None:
|
||||||
|
lx = self.lora_down(x)
|
||||||
|
lx = self.lora_up(lx)
|
||||||
|
return self.org_forward(x) + lx * self.multiplier * self.scale
|
||||||
|
else:
|
||||||
|
lxs = [lora_down(x) for lora_down in self.lora_down]
|
||||||
|
lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
|
||||||
|
return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
@@ -257,6 +325,11 @@ def create_network(
|
|||||||
if train_blocks is not None:
|
if train_blocks is not None:
|
||||||
assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
|
assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
|
||||||
|
|
||||||
|
# split qkv
|
||||||
|
split_qkv = kwargs.get("split_qkv", False)
|
||||||
|
if split_qkv is not None:
|
||||||
|
split_qkv = True if split_qkv == "True" else False
|
||||||
|
|
||||||
# すごく引数が多いな ( ^ω^)・・・
|
# すごく引数が多いな ( ^ω^)・・・
|
||||||
network = LoRANetwork(
|
network = LoRANetwork(
|
||||||
text_encoders,
|
text_encoders,
|
||||||
@@ -270,6 +343,7 @@ def create_network(
|
|||||||
conv_lora_dim=conv_dim,
|
conv_lora_dim=conv_dim,
|
||||||
conv_alpha=conv_alpha,
|
conv_alpha=conv_alpha,
|
||||||
train_blocks=train_blocks,
|
train_blocks=train_blocks,
|
||||||
|
split_qkv=split_qkv,
|
||||||
varbose=True,
|
varbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -311,10 +385,34 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
|||||||
modules_dim[lora_name] = dim
|
modules_dim[lora_name] = dim
|
||||||
# logger.info(lora_name, value.size(), dim)
|
# logger.info(lora_name, value.size(), dim)
|
||||||
|
|
||||||
|
# # split qkv
|
||||||
|
# double_qkv_rank = None
|
||||||
|
# single_qkv_rank = None
|
||||||
|
# rank = None
|
||||||
|
# for lora_name, dim in modules_dim.items():
|
||||||
|
# if "double" in lora_name and "qkv" in lora_name:
|
||||||
|
# double_qkv_rank = dim
|
||||||
|
# elif "single" in lora_name and "linear1" in lora_name:
|
||||||
|
# single_qkv_rank = dim
|
||||||
|
# elif rank is None:
|
||||||
|
# rank = dim
|
||||||
|
# if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
|
||||||
|
# break
|
||||||
|
# split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
|
||||||
|
# single_qkv_rank is not None and single_qkv_rank != rank
|
||||||
|
# )
|
||||||
|
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
|
||||||
|
|
||||||
module_class = LoRAInfModule if for_inference else LoRAModule
|
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||||
|
|
||||||
network = LoRANetwork(
|
network = LoRANetwork(
|
||||||
text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
text_encoders,
|
||||||
|
flux,
|
||||||
|
multiplier=multiplier,
|
||||||
|
modules_dim=modules_dim,
|
||||||
|
modules_alpha=modules_alpha,
|
||||||
|
module_class=module_class,
|
||||||
|
split_qkv=split_qkv,
|
||||||
)
|
)
|
||||||
return network, weights_sd
|
return network, weights_sd
|
||||||
|
|
||||||
@@ -344,6 +442,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
modules_dim: Optional[Dict[str, int]] = None,
|
modules_dim: Optional[Dict[str, int]] = None,
|
||||||
modules_alpha: Optional[Dict[str, int]] = None,
|
modules_alpha: Optional[Dict[str, int]] = None,
|
||||||
train_blocks: Optional[str] = None,
|
train_blocks: Optional[str] = None,
|
||||||
|
split_qkv: bool = False,
|
||||||
varbose: Optional[bool] = False,
|
varbose: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -357,6 +456,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.rank_dropout = rank_dropout
|
self.rank_dropout = rank_dropout
|
||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||||
|
self.split_qkv = split_qkv
|
||||||
|
|
||||||
self.loraplus_lr_ratio = None
|
self.loraplus_lr_ratio = None
|
||||||
self.loraplus_unet_lr_ratio = None
|
self.loraplus_unet_lr_ratio = None
|
||||||
@@ -373,6 +473,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||||
)
|
)
|
||||||
|
if self.split_qkv:
|
||||||
|
logger.info(f"split qkv for LoRA")
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(
|
def create_modules(
|
||||||
@@ -420,6 +522,14 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
skipped.append(lora_name)
|
skipped.append(lora_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# qkv split
|
||||||
|
split_dims = None
|
||||||
|
if is_flux and split_qkv:
|
||||||
|
if "double" in lora_name and "qkv" in lora_name:
|
||||||
|
split_dims = [3072] * 3
|
||||||
|
elif "single" in lora_name and "linear1" in lora_name:
|
||||||
|
split_dims = [3072] * 3 + [12288]
|
||||||
|
|
||||||
lora = module_class(
|
lora = module_class(
|
||||||
lora_name,
|
lora_name,
|
||||||
child_module,
|
child_module,
|
||||||
@@ -429,6 +539,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
rank_dropout=rank_dropout,
|
rank_dropout=rank_dropout,
|
||||||
module_dropout=module_dropout,
|
module_dropout=module_dropout,
|
||||||
|
split_dims=split_dims,
|
||||||
)
|
)
|
||||||
loras.append(lora)
|
loras.append(lora)
|
||||||
return loras, skipped
|
return loras, skipped
|
||||||
@@ -492,6 +603,111 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
info = self.load_state_dict(weights_sd, False)
|
info = self.load_state_dict(weights_sd, False)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict, strict=True):
    """Load weights, converting merged q/k/v LoRA weights to the split layout.

    When ``self.split_qkv`` is false the call is forwarded unchanged.
    Otherwise every merged ``<lora_name>.lora_down.weight`` /
    ``<lora_name>.lora_up.weight`` entry whose key contains "double" and
    "qkv" (three splits of 3072) or "single" and "linear1" (three splits of
    3072 plus one of 12288) is replaced by per-split
    ``<lora_name>.lora_down.{i}.weight`` / ``<lora_name>.lora_up.{i}.weight``
    entries before delegating to the parent implementation. ``alpha``
    entries are passed through unchanged.

    The conversion works on a shallow copy, so the mapping passed by the
    caller is never modified. Entries that are already in the split layout
    are left untouched instead of being chunked a second time.

    Args:
        state_dict: mapping of parameter names to tensors.
        strict: forwarded to the parent ``load_state_dict``.

    Returns:
        The value returned by the parent ``load_state_dict``.
    """
    if not self.split_qkv:
        return super().load_state_dict(state_dict, strict)

    from collections import OrderedDict

    # Work on a shallow copy so the caller's dict keeps its merged keys.
    # The optional ``_metadata`` attribute (attached by torch to the dicts it
    # produces) is carried over so the parent sees the same information.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata

    # split qkv
    for key in list(state_dict.keys()):
        if "double" in key and "qkv" in key:
            split_dims = [3072] * 3
        elif "single" in key and "linear1" in key:
            split_dims = [3072] * 3 + [12288]
        else:
            continue

        lora_name = key.split(".")[0]
        num_splits = len(split_dims)

        # Match the exact merged key suffix: a plain substring test would also
        # hit already-split keys such as "<name>.lora_down.0.weight".
        if key.endswith(".lora_down.weight"):
            # dense weight (rank * num_splits, in_dim): one row chunk per split
            weight = state_dict.pop(key)
            for i, split_w in enumerate(torch.chunk(weight, num_splits, dim=0)):
                state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w
        elif key.endswith(".lora_up.weight"):
            # block-sparse weight (sum(split_dims), rank * num_splits):
            # only the j-th diagonal block belongs to split j
            weight = state_dict.pop(key)
            rank = weight.size(1) // num_splits
            row = 0
            for j, split_dim in enumerate(split_dims):
                state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[row : row + split_dim, j * rank : (j + 1) * rank]
                row += split_dim

        # alpha is unchanged

    return super().load_state_dict(state_dict, strict)
|
||||||
|
|
||||||
|
def state_dict(self, destination=None, prefix="", keep_vars=False):
    """Return the state dict with split q/k/v LoRA weights merged back together.

    When ``self.split_qkv`` is false the parent result is returned unchanged.
    Otherwise, for every module whose key contains "double" and "qkv" (three
    splits of 3072) or "single" and "linear1" (three splits of 3072 plus one
    of 12288), the per-split ``lora_down.{i}.weight`` / ``lora_up.{i}.weight``
    entries are replaced by a single ``lora_down.weight`` /
    ``lora_up.weight`` pair:

    * the down weights are concatenated along dim 0;
    * the up weights are placed on the diagonal blocks of an otherwise zero
      matrix of shape ``(sum(split_dims), rank * num_splits)``.

    ``alpha`` is copied as is. All other entries are passed through.

    Args:
        destination: forwarded to the parent ``state_dict``.
        prefix: forwarded to the parent ``state_dict``; it is also skipped
            when extracting the module name from a key.
        keep_vars: forwarded to the parent ``state_dict``.

    Returns:
        A new plain ``dict`` in the merged layout (or the parent result when
        ``split_qkv`` is disabled).
    """
    if not self.split_qkv:
        return super().state_dict(destination, prefix, keep_vars)

    # merge qkv
    state_dict = super().state_dict(destination, prefix, keep_vars)
    new_state_dict = {}
    for key in list(state_dict.keys()):
        if "double" in key and "qkv" in key:
            split_dims = [3072] * 3
        elif "single" in key and "linear1" in key:
            split_dims = [3072] * 3 + [12288]
        else:
            new_state_dict[key] = state_dict[key]
            continue

        if key not in state_dict:
            continue  # already merged (popped while handling a sibling key)

        # The module name is everything up to the first "." after the prefix;
        # splitting the raw key would break as soon as the prefix has a dot.
        name_start = len(prefix) if key.startswith(prefix) else 0
        lora_name = key[:name_start] + key[name_start:].split(".")[0]
        num_splits = len(split_dims)

        if f"{lora_name}.lora_down.0.weight" not in state_dict:
            # this module is not stored in split layout: nothing to merge
            new_state_dict[key] = state_dict[key]
            continue

        # (rank, in_dim) * num_splits
        down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(num_splits)]
        # (split_dim, rank) * num_splits
        up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(num_splits)]

        alpha = state_dict.pop(f"{lora_name}.alpha")

        # merge down weight: (rank, in_dim) * num_splits -> (rank * num_splits, in_dim)
        down_weight = torch.cat(down_weights, dim=0)

        # merge up weight: (sum of split_dims, rank * num_splits), zero outside
        # the diagonal blocks so each split only sees its own down projection
        rank = up_weights[0].size(1)
        up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
        row = 0
        for j, split_dim in enumerate(split_dims):
            up_weight[row : row + split_dim, j * rank : (j + 1) * rank] = up_weights[j]
            row += split_dim

        new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
        new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
        new_state_dict[f"{lora_name}.alpha"] = alpha

    return new_state_dict
|
||||||
|
|
||||||
def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True):
|
def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True):
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
|
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
|
||||||
|
|||||||
Reference in New Issue
Block a user