Precalculate .safetensors model hashes after training

This commit is contained in:
space-nuko
2023-01-23 17:21:04 -08:00
parent 93df55d597
commit f7fbdc4b2a
2 changed files with 55 additions and 0 deletions

View File

@@ -7,6 +7,8 @@ import math
import os
import torch
from library import train_util
class LoRAModule(torch.nn.Module):
"""
@@ -221,6 +223,14 @@ class LoRANetwork(torch.nn.Module):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)