format by black

commit db80c5a2e7
parent 89aae3e04f
Author: Kohya S
Date:   2023-08-03 20:14:04 +09:00


@@ -1,4 +1,3 @@
 import math
 import argparse
 import os
@@ -13,10 +12,10 @@ CLAMP_QUANTILE = 0.99

 def load_state_dict(file_name, dtype):
-    if os.path.splitext(file_name)[1] == '.safetensors':
+    if os.path.splitext(file_name)[1] == ".safetensors":
         sd = load_file(file_name)
     else:
-        sd = torch.load(file_name, map_location='cpu')
+        sd = torch.load(file_name, map_location="cpu")

     for key in list(sd.keys()):
         if type(sd[key]) == torch.Tensor:
             sd[key] = sd[key].to(dtype)
@@ -29,7 +28,7 @@ def save_to_file(file_name, state_dict, dtype):
         if type(state_dict[key]) == torch.Tensor:
             state_dict[key] = state_dict[key].to(dtype)

-    if os.path.splitext(file_name)[1] == '.safetensors':
+    if os.path.splitext(file_name)[1] == ".safetensors":
         save_file(state_dict, file_name)
     else:
         torch.save(state_dict, file_name)
@@ -45,7 +44,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
         # merge
         print(f"merging...")
         for key in tqdm(list(lora_sd.keys())):
-            if 'lora_down' not in key:
+            if "lora_down" not in key:
                 continue

             lora_module_name = key[: key.rfind(".lora_down")]
@@ -53,8 +52,8 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
             down_weight = lora_sd[key]
             network_dim = down_weight.size()[0]

-            up_weight = lora_sd[lora_module_name + '.lora_up.weight']
-            alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
+            up_weight = lora_sd[lora_module_name + ".lora_up.weight"]
+            alpha = lora_sd.get(lora_module_name + ".alpha", network_dim)

             in_dim = down_weight.size()[1]
             out_dim = up_weight.size()[0]
@@ -76,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
                 down_weight = down_weight.to(device)

             # W <- W + U * D
-            scale = (alpha / network_dim)
+            scale = alpha / network_dim
             if device:  # and isinstance(scale, torch.Tensor):
                 scale = scale.to(device)
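The `W <- W + U * D` comment above compresses the whole merge rule: each LoRA pair contributes `ratio * (up @ down) * (alpha / network_dim)` to the accumulated weight. A minimal sketch of the linear case, with invented dimensions (the names mirror the diff; nothing below is taken from the repository itself):

import torch

# Hypothetical shapes: rank-4 LoRA on a 16 -> 32 linear layer.
network_dim = 4
down_weight = torch.randn(network_dim, 16)  # lora_down: (rank, in_dim)
up_weight = torch.randn(32, network_dim)    # lora_up:   (out_dim, rank)
alpha, ratio = 1.0, 0.7

# Effective delta applied to the base weight, as in the hunk above.
scale = alpha / network_dim
delta = ratio * (up_weight @ down_weight) * scale
print(delta.shape)  # torch.Size([32, 16])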
@@ -84,8 +83,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
             if not conv2d:  # linear
                 weight = weight + ratio * (up_weight @ down_weight) * scale
             elif kernel_size == (1, 1):
-                weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
-                                           ).unsqueeze(2).unsqueeze(3) * scale
+                weight = (
+                    weight
+                    + ratio
+                    * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                    * scale
+                )
             else:
                 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                 weight = weight + ratio * conved * scale
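The `else` branch folds the rank-r 3x3 `lora_down` and the 1x1 `lora_up` into a single 3x3 kernel by running `conv2d` over the permuted down weight. A self-contained check that this composition matches applying the two convolutions in sequence (dimensions invented for illustration):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_dim, rank, out_dim = 8, 4, 16          # invented dimensions
down = torch.randn(rank, in_dim, 3, 3)    # lora_down weight for a 3x3 conv
up = torch.randn(out_dim, rank, 1, 1)     # lora_up weight (1x1 by construction)

# Same composition as the else branch above: one (out_dim, in_dim, 3, 3) kernel.
merged = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)

x = torch.randn(1, in_dim, 10, 10)
two_step = F.conv2d(F.conv2d(x, down), up)  # down conv, then 1x1 up conv
one_step = F.conv2d(x, merged)              # single merged conv
print(torch.allclose(two_step, one_step, atol=1e-4))  # expect: True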
@@ -97,7 +100,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
     merged_lora_sd = {}
     with torch.no_grad():
         for lora_module_name, mat in tqdm(list(merged_sd.items())):
-            conv2d = (len(mat.size()) == 4)
+            conv2d = len(mat.size()) == 4
             kernel_size = None if not conv2d else mat.size()[2:4]
             conv2d_3x3 = conv2d and kernel_size != (1, 1)
             out_dim, in_dim = mat.size()[0:2]
@@ -133,9 +136,9 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
             up_weight = U
             down_weight = Vh

-            merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
-            merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
-            merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
+            merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
+            merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
+            merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)

     return merged_lora_sd
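The SVD that produces `U` and `Vh` above falls between the hunks shown, so only its output handling appears in this diff. For orientation, a minimal sketch of the technique it implements: re-factoring a full delta matrix to a lower rank, clamping outliers at the quantile given by `CLAMP_QUANTILE` from the file header. Treat this as an illustration, not the file's exact code:

import torch

CLAMP_QUANTILE = 0.99

def svd_reduce(mat: torch.Tensor, new_rank: int):
    """Factor a (out_dim, in_dim) delta into rank-limited up/down matrices."""
    U, S, Vh = torch.linalg.svd(mat)
    U = U[:, :new_rank] @ torch.diag(S[:new_rank])  # fold singular values into U
    Vh = Vh[:new_rank, :]

    # Clamp extreme values so the factors stay numerically tame.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    return U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)

up_weight, down_weight = svd_reduce(torch.randn(32, 16), new_rank=4)
print(up_weight.shape, down_weight.shape)  # (32, 4) and (4, 16)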
@@ -144,11 +147,11 @@ def merge(args):
     assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / models and ratios must have the same number of entries"

     def str_to_dtype(p):
-        if p == 'float':
+        if p == "float":
             return torch.float
-        if p == 'fp16':
+        if p == "fp16":
             return torch.float16
-        if p == 'bf16':
+        if p == "bf16":
             return torch.bfloat16
         return None
@@ -166,26 +169,40 @@ def merge(args):

 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--save_precision", type=str, default=None,
-                        choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / change precision when saving; defaults to the merge precision if omitted")
-    parser.add_argument("--precision", type=str, default="float",
-                        choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / precision for merge computation; float is recommended")
-    parser.add_argument("--save_to", type=str, default=None,
-                        help="destination file name: ckpt or safetensors file / destination file name, ckpt or safetensors")
-    parser.add_argument("--models", type=str, nargs='*',
-                        help="LoRA models to merge: ckpt or safetensors file / LoRA models to merge, ckpt or safetensors")
-    parser.add_argument("--ratios", type=float, nargs='*',
-                        help="ratios for each model / ratio for each LoRA model")
-    parser.add_argument("--new_rank", type=int, default=4,
-                        help="Specify rank of output LoRA / rank (dim) of the output LoRA")
-    parser.add_argument("--new_conv_rank", type=int, default=None,
-                        help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / rank (dim) of the output Conv2D 3x3 LoRA; None means same as new_rank")
+    parser.add_argument(
+        "--save_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in saving, same to merging if omitted / change precision when saving; defaults to the merge precision if omitted",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="float",
+        choices=["float", "fp16", "bf16"],
+        help="precision in merging (float is recommended) / precision for merge computation; float is recommended",
+    )
+    parser.add_argument(
+        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / destination file name, ckpt or safetensors"
+    )
+    parser.add_argument(
+        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / LoRA models to merge, ckpt or safetensors"
+    )
+    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / ratio for each LoRA model")
+    parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / rank (dim) of the output LoRA")
+    parser.add_argument(
+        "--new_conv_rank",
+        type=int,
+        default=None,
+        help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / rank (dim) of the output Conv2D 3x3 LoRA; None means same as new_rank",
+    )
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / device for computation; use cuda for GPU")

     return parser


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = setup_parser()

     args = parser.parse_args()
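For reference, a hypothetical invocation exercising the options defined above. The diff does not name the script file, so the file names here, including the script name, are invented:

python svd_merge_lora.py \
    --models lora_a.safetensors lora_b.safetensors --ratios 1.0 0.5 \
    --new_rank 16 --new_conv_rank 8 \
    --device cuda --precision float --save_precision fp16 \
    --save_to merged.safetensors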