From da48f74e7bce80c4a708a0328e10f00dc9fdbe0a Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 18 Jan 2023 23:00:16 -0800 Subject: [PATCH] Add new version model/VAE hash to training metadata --- library/train_util.py | 15 ++++++++++++++- train_network.py | 2 ++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index e3ff1a38..59bd2a03 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -11,6 +11,7 @@ import glob import math import os import random +import hashlib from tqdm import tqdm import torch @@ -753,9 +754,9 @@ def default(val, d): def model_hash(filename): + """Old model hash used by stable-diffusion-webui""" try: with open(filename, "rb") as file: - import hashlib m = hashlib.sha256() file.seek(0x100000) @@ -765,6 +766,18 @@ def model_hash(filename): return 'NOFILE' +def calculate_sha256(filename): + """New model hash used by stable-diffusion-webui""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + # flash attention forwards and backwards # https://arxiv.org/abs/2205.14135 diff --git a/train_network.py b/train_network.py index 098145f2..c759e66b 100644 --- a/train_network.py +++ b/train_network.py @@ -254,6 +254,7 @@ def train(args): sd_model_name = args.pretrained_model_name_or_path if os.path.exists(sd_model_name): metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) sd_model_name = os.path.basename(sd_model_name) metadata["ss_sd_model_name"] = sd_model_name @@ -261,6 +262,7 @@ def train(args): vae_name = args.vae if os.path.exists(vae_name): metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) vae_name = os.path.basename(vae_name) metadata["ss_vae_name"] = vae_name