From 4ad8e75291ce77974b6441c9710a459cc95ee802 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Fri, 10 Mar 2023 21:10:22 +0900
Subject: [PATCH] fix to work with dim>320

---
 networks/resize_lora.py    | 1 -
 networks/svd_merge_lora.py | 9 +++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index 1a8110c4..dfacd666 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -1,6 +1,5 @@
 # Convert LoRA to different rank approximation (should only be used to go to lower rank)
 # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
-# Thanks to cloneofsimo and kohya
 
 import argparse
 import torch
diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py
index c8e39b80..3a03b0d5 100644
--- a/networks/svd_merge_lora.py
+++ b/networks/svd_merge_lora.py
@@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype):
   return sd
 
 
-def save_to_file(file_name, model, state_dict, dtype):
+def save_to_file(file_name, state_dict, dtype):
   if dtype is not None:
     for key in list(state_dict.keys()):
       if type(state_dict[key]) == torch.Tensor:
         state_dict[key] = state_dict[key].to(dtype)
 
   if os.path.splitext(file_name)[1] == '.safetensors':
-    save_file(model, file_name)
+    save_file(state_dict, file_name)
   else:
-    torch.save(model, file_name)
+    torch.save(state_dict, file_name)
 
 
 def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
@@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
       mat = mat.squeeze()
 
       module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
+      module_new_rank = min(module_new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
 
       U, S, Vh = torch.linalg.svd(mat)
 
@@ -156,7 +157,7 @@ def merge(args):
   state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
 
   print(f"saving model to: {args.save_to}")
-  save_to_file(args.save_to, state_dict, state_dict, save_dtype)
+  save_to_file(args.save_to, state_dict, save_dtype)
 
 
 if __name__ == '__main__':