diff --git a/networks/resize_lora.py b/networks/resize_lora.py index c7418a5b..60c86f25 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -6,6 +6,7 @@ import argparse import os import torch from safetensors.torch import load_file, save_file +from tqdm import tqdm def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == '.safetensors': @@ -65,7 +66,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype): print("resizing lora...") with torch.no_grad(): - for key, value in lora_sd.items(): + for key, value in tqdm(lora_sd.items()): if 'lora_down' in key: block_down_name = key.split(".")[0] lora_down_weight = value @@ -83,6 +84,11 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype): lora_down_weight = lora_down_weight.squeeze() lora_up_weight = lora_up_weight.squeeze() + if args.device: + org_device = lora_up_weight.device + lora_up_weight = lora_up_weight.to(args.device) + lora_down_weight = lora_down_weight.to(args.device) + full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) U, S, Vh = torch.linalg.svd(full_weight_matrix) @@ -103,6 +109,10 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype): if conv2d: U = U.unsqueeze(2).unsqueeze(3) Vh = Vh.unsqueeze(2).unsqueeze(3) + + if args.device: + U = U.to(org_device) + Vh = Vh.to(org_device) o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous() o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() @@ -143,13 +153,14 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if ommitted") + choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if ommitted / 保存時の精度、未指定時はfloat") parser.add_argument("--new_rank", type=int, default=4, - help="Specify rank of output LoRA") + help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") parser.add_argument("--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") parser.add_argument("--model", type=str, default=None, - help="LoRA model to resize at to new rank: ckpt or safetensors file") + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") args = parser.parse_args() resize(args)