fix to work with dim>320

This commit is contained in:
Kohya S
2023-03-10 21:10:22 +09:00
parent e355b5e1d3
commit 4ad8e75291
2 changed files with 5 additions and 5 deletions

View File

@@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype):
return sd
-def save_to_file(file_name, model, state_dict, dtype):
+def save_to_file(file_name, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
-save_file(model, file_name)
+save_file(state_dict, file_name)
else:
-torch.save(model, file_name)
+torch.save(state_dict, file_name)
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
@@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
mat = mat.squeeze()
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
+module_new_rank = min(module_new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
U, S, Vh = torch.linalg.svd(mat)
@@ -156,7 +157,7 @@ def merge(args):
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
print(f"saving model to: {args.save_to}")
-save_to_file(args.save_to, state_dict, state_dict, save_dtype)
+save_to_file(args.save_to, state_dict, save_dtype)
if __name__ == '__main__':