Add device option to calculate on GPU

This commit is contained in:
Kohya S
2023-02-04 20:36:10 +09:00
parent b18db9fbbd
commit 8cbd3f4fca

View File

@@ -6,6 +6,7 @@ import argparse
import os import os
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from tqdm import tqdm
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == '.safetensors':
@@ -65,7 +66,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
print("resizing lora...") print("resizing lora...")
with torch.no_grad(): with torch.no_grad():
for key, value in lora_sd.items(): for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key: if 'lora_down' in key:
block_down_name = key.split(".")[0] block_down_name = key.split(".")[0]
lora_down_weight = value lora_down_weight = value
@@ -83,6 +84,11 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
lora_down_weight = lora_down_weight.squeeze() lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze() lora_up_weight = lora_up_weight.squeeze()
if args.device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
U, S, Vh = torch.linalg.svd(full_weight_matrix) U, S, Vh = torch.linalg.svd(full_weight_matrix)
@@ -104,6 +110,10 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
U = U.unsqueeze(2).unsqueeze(3) U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3) Vh = Vh.unsqueeze(2).unsqueeze(3)
if args.device:
U = U.to(org_device)
Vh = Vh.to(org_device)
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous() o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype) o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
@@ -143,13 +153,14 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None, parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if ommitted") choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if ommitted / 保存時の精度、未指定時はfloat")
parser.add_argument("--new_rank", type=int, default=4, parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA") help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--save_to", type=str, default=None, parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None, parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file") help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args() args = parser.parse_args()
resize(args) resize(args)