feat: Add option to split projection layers and apply LoRA

This commit is contained in:
Kohya S
2024-08-24 16:35:43 +09:00
parent 2e89cd2cc6
commit cf689e7aa6
4 changed files with 323 additions and 66 deletions

View File

@@ -18,7 +18,7 @@ def main(file):
keys = list(sd.keys())
for key in keys:
if "lora_up" in key or "lora_down" in key:
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key:
values.append((key, sd[key]))
print(f"number of LoRA modules: {len(values)}")

View File

@@ -266,11 +266,12 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha")
scale = alpha / rank
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
scale_down = scale
@@ -279,23 +280,49 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
scale_down *= 2
scale_up /= 2
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
num_splits = len(ait_keys)
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
# down_weight is copied to each split
ait_sd.update({k: down_weight * scale_down for k in ait_down_keys})
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# up_weight is split to each split
ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
# check upweight is sparse or not
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")
# make ai-toolkit weight
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})
# up_weight is sparse: only non-zero values are copied to each split
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
def convert_sd_scripts_to_ai_toolkit(sds_sd):

View File

@@ -39,6 +39,7 @@ class LoRAModule(torch.nn.Module):
dropout=None,
rank_dropout=None,
module_dropout=None,
split_dims: Optional[List[int]] = None,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
@@ -52,16 +53,34 @@ class LoRAModule(torch.nn.Module):
out_dim = org_module.out_features
self.lora_dim = lora_dim
self.split_dims = split_dims
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
if split_dims is None:
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
# conv2d not supported
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
# print(f"split_dims: {split_dims}")
self.lora_down = torch.nn.ModuleList(
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
)
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
for lora_down in self.lora_down:
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
for lora_up in self.lora_up:
torch.nn.init.zeros_(lora_up.weight)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
@@ -70,9 +89,6 @@ class LoRAModule(torch.nn.Module):
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
@@ -92,30 +108,56 @@ class LoRAModule(torch.nn.Module):
if torch.rand(1) < self.module_dropout:
return org_forwarded
lx = self.lora_down(x)
if self.split_dims is None:
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
# scaling for rank dropout: treat as if the rank is changed
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale
else:
scale = self.scale
lxs = [lora_down(x) for lora_down in self.lora_down]
lx = self.lora_up(lx)
# normal dropout
if self.dropout is not None and self.training:
lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
return org_forwarded + lx * self.multiplier * scale
# rank dropout
if self.rank_dropout is not None and self.training:
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
for i in range(len(lxs)):
if len(lx.size()) == 3:
masks[i] = masks[i].unsqueeze(1)
elif len(lx.size()) == 4:
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
lxs[i] = lxs[i] * masks[i]
# scaling for rank dropout: treat as if the rank is changed
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
class LoRAInfModule(LoRAModule):
@@ -152,31 +194,50 @@ class LoRAInfModule(LoRAModule):
if device is None:
device = org_device
# get up/down weight
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
if self.split_dims is None:
# get up/down weight
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# split_dims
total_dims = sum(self.split_dims)
for i in range(len(self.split_dims)):
# get up/down weight
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
# pad up_weight -> (total_dims, rank)
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
# merge weight
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
# 復元できるマージのため、このモジュールのweightを返す
def get_weight(self, multiplier=None):
@@ -211,7 +272,14 @@ class LoRAInfModule(LoRAModule):
def default_forward(self, x):
# logger.info(f"default_forward {self.lora_name} {x.size()}")
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
if self.split_dims is None:
lx = self.lora_down(x)
lx = self.lora_up(lx)
return self.org_forward(x) + lx * self.multiplier * self.scale
else:
lxs = [lora_down(x) for lora_down in self.lora_down]
lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
def forward(self, x):
if not self.enabled:
@@ -257,6 +325,11 @@ def create_network(
if train_blocks is not None:
assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
# split qkv
split_qkv = kwargs.get("split_qkv", False)
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoders,
@@ -270,6 +343,7 @@ def create_network(
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
train_blocks=train_blocks,
split_qkv=split_qkv,
varbose=True,
)
@@ -311,10 +385,34 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
# # split qkv
# double_qkv_rank = None
# single_qkv_rank = None
# rank = None
# for lora_name, dim in modules_dim.items():
# if "double" in lora_name and "qkv" in lora_name:
# double_qkv_rank = dim
# elif "single" in lora_name and "linear1" in lora_name:
# single_qkv_rank = dim
# elif rank is None:
# rank = dim
# if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
# break
# split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
# single_qkv_rank is not None and single_qkv_rank != rank
# )
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
text_encoders,
flux,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
split_qkv=split_qkv,
)
return network, weights_sd
@@ -344,6 +442,7 @@ class LoRANetwork(torch.nn.Module):
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_blocks: Optional[str] = None,
split_qkv: bool = False,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -357,6 +456,7 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all"
self.split_qkv = split_qkv
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
@@ -373,6 +473,8 @@ class LoRANetwork(torch.nn.Module):
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
if self.split_qkv:
logger.info(f"split qkv for LoRA")
# create module instances
def create_modules(
@@ -420,6 +522,14 @@ class LoRANetwork(torch.nn.Module):
skipped.append(lora_name)
continue
# qkv split
split_dims = None
if is_flux and split_qkv:
if "double" in lora_name and "qkv" in lora_name:
split_dims = [3072] * 3
elif "single" in lora_name and "linear1" in lora_name:
split_dims = [3072] * 3 + [12288]
lora = module_class(
lora_name,
child_module,
@@ -429,6 +539,7 @@ class LoRANetwork(torch.nn.Module):
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
)
loras.append(lora)
return loras, skipped
@@ -492,6 +603,111 @@ class LoRANetwork(torch.nn.Module):
info = self.load_state_dict(weights_sd, False)
return info
def load_state_dict(self, state_dict, strict=True):
# override to convert original weight to splitted qkv weight
if not self.split_qkv:
return super().load_state_dict(state_dict, strict)
# split qkv
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
else:
continue
weight = state_dict[key]
lora_name = key.split(".")[0]
if "lora_down" in key and "weight" in key:
# dense weight (rank*3, in_dim)
split_weight = torch.chunk(weight, len(split_dims), dim=0)
for i, split_w in enumerate(split_weight):
state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w
del state_dict[key]
# print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}")
elif "lora_up" in key and "weight" in key:
# sparse weight (out_dim=sum(split_dims), rank*3)
rank = weight.size(1) // len(split_dims)
i = 0
for j in range(len(split_dims)):
state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank]
i += split_dims[j]
del state_dict[key]
# # check is sparse
# i = 0
# is_zero = True
# for j in range(len(split_dims)):
# for k in range(len(split_dims)):
# if j == k:
# continue
# is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0)
# i += split_dims[j]
# if not is_zero:
# logger.warning(f"weight is not sparse: {key}")
# else:
# logger.info(f"weight is sparse: {key}")
# print(
# f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}"
# )
# alpha is unchanged
return super().load_state_dict(state_dict, strict)
def state_dict(self, destination=None, prefix="", keep_vars=False):
if not self.split_qkv:
return super().state_dict(destination, prefix, keep_vars)
# merge qkv
state_dict = super().state_dict(destination, prefix, keep_vars)
new_state_dict = {}
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
else:
new_state_dict[key] = state_dict[key]
continue
if key not in state_dict:
continue # already merged
lora_name = key.split(".")[0]
# (rank, in_dim) * 3
down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
# (split dim, rank) * 3
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
alpha = state_dict.pop(f"{lora_name}.alpha")
# merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
# merge up weight (sum of split_dim, rank*3)
rank = up_weights[0].size(1)
up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
i = 0
for j in range(len(split_dims)):
up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
i += split_dims[j]
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha
# print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
# )
print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
return new_state_dict
def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")