mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to work with dim>320
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
||||||
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||||
# Thanks to cloneofsimo and kohya
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype):
|
|||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, state_dict, dtype):
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
for key in list(state_dict.keys()):
|
for key in list(state_dict.keys()):
|
||||||
if type(state_dict[key]) == torch.Tensor:
|
if type(state_dict[key]) == torch.Tensor:
|
||||||
state_dict[key] = state_dict[key].to(dtype)
|
state_dict[key] = state_dict[key].to(dtype)
|
||||||
|
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||||
save_file(model, file_name)
|
save_file(state_dict, file_name)
|
||||||
else:
|
else:
|
||||||
torch.save(model, file_name)
|
torch.save(state_dict, file_name)
|
||||||
|
|
||||||
|
|
||||||
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
||||||
@@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
mat = mat.squeeze()
|
mat = mat.squeeze()
|
||||||
|
|
||||||
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
||||||
|
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(mat)
|
U, S, Vh = torch.linalg.svd(mat)
|
||||||
|
|
||||||
@@ -156,7 +157,7 @@ def merge(args):
|
|||||||
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
|
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
print(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
save_to_file(args.save_to, state_dict, save_dtype)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user