Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-09 14:45:19 +00:00.
Merge pull request #102 from space-nuko/precalculate-hashes
Precalculate .safetensors model hashes after training
This commit is contained in:
@@ -12,6 +12,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -25,6 +26,7 @@ from PIL import Image
|
|||||||
import cv2
|
import cv2
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
import safetensors.torch
|
||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
|
|
||||||
@@ -790,6 +792,49 @@ def calculate_sha256(filename):
|
|||||||
return hash_sha256.hexdigest()
|
return hash_sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Args:
        tensors: state dict of tensors to serialize with ``safetensors``.
        metadata: str->str metadata dict; only keys prefixed ``ss_`` are
            retained for hashing.

    Returns:
        (model_hash, legacy_hash) tuple of hex-digest strings.
    """
    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    # Serialize in memory and hash the resulting buffer. Named "serialized"
    # instead of "bytes" to avoid shadowing the builtin ``bytes`` type.
    serialized = safetensors.torch.save(tensors, metadata)
    b = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
|
||||||
|
|
||||||
|
|
||||||
|
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    # Legacy webui scheme: sha256 over the 64 KiB window that starts at the
    # 1 MiB offset, truncated to the first 8 hex characters.
    digest = hashlib.sha256()
    b.seek(0x100000)
    digest.update(b.read(0x10000))
    return digest.hexdigest()[:8]
|
||||||
|
|
||||||
|
|
||||||
|
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    # The .safetensors layout starts with an 8-byte little-endian length of the
    # JSON header; the hash covers everything after that header.
    sha = hashlib.sha256()
    chunk_size = 1024 * 1024

    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")
    b.seek(header_len + 8)

    # Stream the tensor data in 1 MiB chunks to keep memory flat.
    while True:
        chunk = b.read(chunk_size)
        if not chunk:
            break
        sha.update(chunk)

    return sha.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
# flash attention forwards and backwards
|
# flash attention forwards and backwards
|
||||||
|
|
||||||
# https://arxiv.org/abs/2205.14135
|
# https://arxiv.org/abs/2205.14135
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import math
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library import train_util
|
||||||
|
|
||||||
|
|
||||||
class LoRAModule(torch.nn.Module):
|
class LoRAModule(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -221,6 +223,14 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
# Precalculate model hashes to save time on indexing
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
|
metadata["sshs_model_hash"] = model_hash
|
||||||
|
metadata["sshs_legacy_hash"] = legacy_hash
|
||||||
|
|
||||||
save_file(state_dict, file, metadata)
|
save_file(state_dict, file, metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
|||||||
Reference in New Issue
Block a user