# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
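#
# Example invocation (the script path, file names and --unit value here are illustrative only):
#
#   python <path to this script> --model dylora.safetensors --unit 4 --save_to lora.safetensors
#
# For a DyLoRA checkpoint trained with a maximum rank of 16, this writes lora-0004.safetensors,
# lora-0008.safetensors and lora-0012.safetensors (every multiple of --unit below the maximum rank).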
import argparse
import os
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np


def load_state_dict(file_name):
    if model_util.is_safetensors(file_name):
        sd = load_file(file_name)
        with safe_open(file_name, framework="pt") as f:
            metadata = f.metadata()
    else:
        sd = torch.load(file_name, map_location="cpu")
        metadata = None

    return sd, metadata


def save_to_file(file_name, model, metadata):
    if model_util.is_safetensors(file_name):
        save_file(model, file_name, metadata)
    else:
        torch.save(model, file_name)


def split_lora_model(lora_sd, unit):
    max_rank = 0

    # Find the maximum rank (dim) among the loaded lora_down weights
    for key, value in lora_sd.items():
        if "lora_down" in key:
            rank = value.size()[0]
            if rank > max_rank:
                max_rank = rank
    print(f"Max rank: {max_rank}")

    rank = unit
    splitted_models = []
    while rank < max_rank:
        print(f"Splitting rank {rank}")
        new_sd = {}
        # DyLoRA trains nested ranks, so the leading rows of lora_down and the leading
        # columns of lora_up form a usable standalone LoRA of the smaller rank
        for key, value in lora_sd.items():
            if "lora_down" in key:
                new_sd[key] = value[:rank].contiguous()
            elif "lora_up" in key:
                new_sd[key] = value[:, :rank].contiguous()
            else:
                new_sd[key] = value  # alpha and other parameters

        splitted_models.append((new_sd, rank))
        rank += unit

    return max_rank, splitted_models


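# Shape sketch for split_lora_model above (the layer sizes are illustrative, not from this script):
# for a linear layer with in_dim 320 and out_dim 320 trained at max rank 16, lora_down is
# (16, 320) and lora_up is (320, 16); slicing to rank 4 gives lora_down[:4] -> (4, 320) and
# lora_up[:, :4] -> (320, 4), so lora_up[:, :4] @ lora_down[:4] is still (320, 320).

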
def split(args):
    print("loading Model...")
    lora_sd, metadata = load_state_dict(args.model)

    print("Splitting Model...")
    original_rank, splitted_models = split_lora_model(lora_sd, args.unit)

    # ckpt files carry no metadata, so guard against None before reading the training comment
    comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
    for state_dict, new_rank in splitted_models:
        # update metadata
        if metadata is None:
            new_metadata = {}
        else:
            new_metadata = metadata.copy()

        new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
        new_metadata["ss_network_dim"] = str(new_rank)

        # recalculate the hashes on the updated metadata so the saved file actually carries them
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
        new_metadata["sshs_model_hash"] = model_hash
        new_metadata["sshs_legacy_hash"] = legacy_hash

        filename, ext = os.path.splitext(args.save_to)
        model_file_name = filename + f"-{new_rank:04d}{ext}"

        print(f"saving model to: {model_file_name}")
        save_to_file(model_file_name, state_dict, new_metadata)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument("--unit", type=int, default=None, help="rank step to split the model into")
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        help="base file name for the destination: ckpt or safetensors file",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="DyLoRA model to split into new ranks: ckpt or safetensors file",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    split(args)

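# Optional post-split check (a sketch, not executed by this script; the file name and rank are
# illustrative): every lora_down tensor in an output file should carry the sliced rank.
#
#   from safetensors.torch import load_file
#   sd = load_file("lora-0004.safetensors")
#   assert {v.size()[0] for k, v in sd.items() if "lora_down" in k} == {4}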