Support for multiple format loras.

This commit is contained in:
Symbiomatrix
2025-04-20 15:08:21 +03:00
committed by woctordho
parent a21b6a917e
commit 63ec59fc0b

View File

@@ -20,6 +20,12 @@ logger = logging.getLogger(__name__)
MIN_SV = 1e-6
# Tune layers to various trainer formats.
LORAFMT1 = ["lora_down", "lora_up"]
LORAFMT2 = ["lora.down", "lora.up"]
LORAFMT3 = ["lora_A", "lora_B"]
LORAFMT = LORAFMT1
# Model save and load functions
@@ -90,8 +96,8 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank, 1, 1).cpu()
del U, S, Vh, weight
return param_dict
@@ -109,8 +115,8 @@ def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, sca
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size).cpu()
param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank).cpu()
del U, S, Vh, weight
return param_dict
@@ -192,6 +198,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
global LORAFMT
network_alpha = None
network_dim = None
verbose_str = "\n"
@@ -201,7 +208,14 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
for key, value in lora_sd.items():
if network_alpha is None and "alpha" in key:
network_alpha = value
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
if (network_dim is None and len(value.size()) == 2
and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key)):
if LORAFMT1[0] in key:
LORAFMT = LORAFMT1
elif LORAFMT2[0] in key:
LORAFMT = LORAFMT2
elif LORAFMT3[0] in key:
LORAFMT = LORAFMT3
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
@@ -225,8 +239,8 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if "lora_down" in key:
block_down_name = key.rsplit(".lora_down", 1)[0]
if LORAFMT[0] in key:
block_down_name = key.rsplit(f".LORAFMT[0]", 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
@@ -234,7 +248,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
lora_up_weight = lora_sd.get(block_up_name + f".LORAFMT[1]." + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
@@ -272,9 +286,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
verbose_str += "\n"
new_alpha = param_dict["new_alpha"]
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
o_lora_sd[block_down_name + f".LORAFMT[0].weight"] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
o_lora_sd[block_up_name + f".LORAFMT[1].weight"] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
block_down_name = None
block_up_name = None