Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
format by black
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
@@ -13,10 +12,10 @@ CLAMP_QUANTILE = 0.99
|
|||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||||
sd = load_file(file_name)
|
sd = load_file(file_name)
|
||||||
else:
|
else:
|
||||||
sd = torch.load(file_name, map_location='cpu')
|
sd = torch.load(file_name, map_location="cpu")
|
||||||
for key in list(sd.keys()):
|
for key in list(sd.keys()):
|
||||||
if type(sd[key]) == torch.Tensor:
|
if type(sd[key]) == torch.Tensor:
|
||||||
sd[key] = sd[key].to(dtype)
|
sd[key] = sd[key].to(dtype)
|
||||||
@@ -29,7 +28,7 @@ def save_to_file(file_name, state_dict, dtype):
|
|||||||
if type(state_dict[key]) == torch.Tensor:
|
if type(state_dict[key]) == torch.Tensor:
|
||||||
state_dict[key] = state_dict[key].to(dtype)
|
state_dict[key] = state_dict[key].to(dtype)
|
||||||
|
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||||
save_file(state_dict, file_name)
|
save_file(state_dict, file_name)
|
||||||
else:
|
else:
|
||||||
torch.save(state_dict, file_name)
|
torch.save(state_dict, file_name)
|
||||||
@@ -45,7 +44,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
# merge
|
# merge
|
||||||
print(f"merging...")
|
print(f"merging...")
|
||||||
for key in tqdm(list(lora_sd.keys())):
|
for key in tqdm(list(lora_sd.keys())):
|
||||||
if 'lora_down' not in key:
|
if "lora_down" not in key:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora_module_name = key[: key.rfind(".lora_down")]
|
lora_module_name = key[: key.rfind(".lora_down")]
|
||||||
@@ -53,8 +52,8 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
down_weight = lora_sd[key]
|
down_weight = lora_sd[key]
|
||||||
network_dim = down_weight.size()[0]
|
network_dim = down_weight.size()[0]
|
||||||
|
|
||||||
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
|
up_weight = lora_sd[lora_module_name + ".lora_up.weight"]
|
||||||
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
|
alpha = lora_sd.get(lora_module_name + ".alpha", network_dim)
|
||||||
|
|
||||||
in_dim = down_weight.size()[1]
|
in_dim = down_weight.size()[1]
|
||||||
out_dim = up_weight.size()[0]
|
out_dim = up_weight.size()[0]
|
||||||
@@ -76,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
down_weight = down_weight.to(device)
|
down_weight = down_weight.to(device)
|
||||||
|
|
||||||
# W <- W + U * D
|
# W <- W + U * D
|
||||||
scale = (alpha / network_dim)
|
scale = alpha / network_dim
|
||||||
|
|
||||||
if device: # and isinstance(scale, torch.Tensor):
|
if device: # and isinstance(scale, torch.Tensor):
|
||||||
scale = scale.to(device)
|
scale = scale.to(device)
|
||||||
@@ -84,8 +83,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
if not conv2d: # linear
|
if not conv2d: # linear
|
||||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||||
elif kernel_size == (1, 1):
|
elif kernel_size == (1, 1):
|
||||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
weight = (
|
||||||
).unsqueeze(2).unsqueeze(3) * scale
|
weight
|
||||||
|
+ ratio
|
||||||
|
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
* scale
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
weight = weight + ratio * conved * scale
|
weight = weight + ratio * conved * scale
|
||||||
@@ -97,7 +100,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
merged_lora_sd = {}
|
merged_lora_sd = {}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||||
conv2d = (len(mat.size()) == 4)
|
conv2d = len(mat.size()) == 4
|
||||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
out_dim, in_dim = mat.size()[0:2]
|
out_dim, in_dim = mat.size()[0:2]
|
||||||
@@ -133,9 +136,9 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
up_weight = U
|
up_weight = U
|
||||||
down_weight = Vh
|
down_weight = Vh
|
||||||
|
|
||||||
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
|
||||||
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
|
||||||
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
|
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
|
||||||
|
|
||||||
return merged_lora_sd
|
return merged_lora_sd
|
||||||
|
|
||||||
@@ -144,11 +147,11 @@ def merge(args):
|
|||||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||||
|
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == 'float':
|
if p == "float":
|
||||||
return torch.float
|
return torch.float
|
||||||
if p == 'fp16':
|
if p == "fp16":
|
||||||
return torch.float16
|
return torch.float16
|
||||||
if p == 'bf16':
|
if p == "bf16":
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -166,26 +169,40 @@ def merge(args):
|
|||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--save_precision", type=str, default=None,
|
parser.add_argument(
|
||||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
"--save_precision",
|
||||||
parser.add_argument("--precision", type=str, default="float",
|
type=str,
|
||||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
default=None,
|
||||||
parser.add_argument("--save_to", type=str, default=None,
|
choices=[None, "float", "fp16", "bf16"],
|
||||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
|
||||||
parser.add_argument("--models", type=str, nargs='*',
|
)
|
||||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
parser.add_argument(
|
||||||
parser.add_argument("--ratios", type=float, nargs='*',
|
"--precision",
|
||||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
type=str,
|
||||||
parser.add_argument("--new_rank", type=int, default=4,
|
default="float",
|
||||||
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
choices=["float", "fp16", "bf16"],
|
||||||
parser.add_argument("--new_conv_rank", type=int, default=None,
|
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
||||||
|
)
|
||||||
|
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||||
|
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--new_conv_rank",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
|
||||||
|
)
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
Reference in New Issue
Block a user