mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix LoRA metadata hash calculation bug in svd_merge_lora.py, sdxl_merge_lora.py, and resize_lora.py closes #1722
This commit is contained in:
@@ -35,12 +35,7 @@ def load_state_dict(file_name, dtype):
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
def save_to_file(file_name, model, metadata):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name, metadata=metadata)
|
||||
else:
|
||||
@@ -430,6 +425,12 @@ def merge(args):
|
||||
else:
|
||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
|
||||
|
||||
# cast to save_dtype before calculating hashes
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
|
||||
state_dict[key] = value.to(save_dtype)
|
||||
|
||||
logger.info(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
@@ -445,7 +446,7 @@ def merge(args):
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
save_to_file(args.save_to, state_dict, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
Reference in New Issue
Block a user