diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 7b740634..03fc545e 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -219,8 +219,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn for key, value in tqdm(lora_sd.items()): weight_name = None if 'lora_down' in key: - block_down_name = key.split(".")[0] - weight_name = key.split(".")[-1] + block_down_name = key.rsplit('.lora_down', 1)[0] + weight_name = key.rsplit(".", 1)[-1] lora_down_weight = value else: continue @@ -283,7 +283,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn def resize(args): + if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')): + raise Exception("The --save_to argument must be specified and must be a .ckpt, .pt, .pth or .safetensors file.") + def str_to_dtype(p): if p == 'float': return torch.float