mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Precalculate .safetensors model hashes after training
This commit is contained in:
@@ -7,6 +7,8 @@ import math
|
||||
import os
|
||||
import torch
|
||||
|
||||
from library import train_util
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
@@ -221,6 +223,14 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
from safetensors.torch import save_file
|
||||
|
||||
# Precalculate model hashes to save time on indexing
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
Reference in New Issue
Block a user