mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add new version model/VAE hash to training metadata
This commit is contained in:
@@ -11,6 +11,7 @@ import glob
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import hashlib
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -753,9 +754,9 @@ def default(val, d):
|
|||||||
|
|
||||||
|
|
||||||
def model_hash(filename):
|
def model_hash(filename):
|
||||||
|
"""Old model hash used by stable-diffusion-webui"""
|
||||||
try:
|
try:
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
import hashlib
|
|
||||||
m = hashlib.sha256()
|
m = hashlib.sha256()
|
||||||
|
|
||||||
file.seek(0x100000)
|
file.seek(0x100000)
|
||||||
@@ -765,6 +766,18 @@ def model_hash(filename):
|
|||||||
return 'NOFILE'
|
return 'NOFILE'
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_sha256(filename):
|
||||||
|
"""New model hash used by stable-diffusion-webui"""
|
||||||
|
hash_sha256 = hashlib.sha256()
|
||||||
|
blksize = 1024 * 1024
|
||||||
|
|
||||||
|
with open(filename, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(blksize), b""):
|
||||||
|
hash_sha256.update(chunk)
|
||||||
|
|
||||||
|
return hash_sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
# flash attention forwards and backwards
|
# flash attention forwards and backwards
|
||||||
|
|
||||||
# https://arxiv.org/abs/2205.14135
|
# https://arxiv.org/abs/2205.14135
|
||||||
|
|||||||
@@ -254,6 +254,7 @@ def train(args):
|
|||||||
sd_model_name = args.pretrained_model_name_or_path
|
sd_model_name = args.pretrained_model_name_or_path
|
||||||
if os.path.exists(sd_model_name):
|
if os.path.exists(sd_model_name):
|
||||||
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
||||||
|
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
|
||||||
sd_model_name = os.path.basename(sd_model_name)
|
sd_model_name = os.path.basename(sd_model_name)
|
||||||
metadata["ss_sd_model_name"] = sd_model_name
|
metadata["ss_sd_model_name"] = sd_model_name
|
||||||
|
|
||||||
@@ -261,6 +262,7 @@ def train(args):
|
|||||||
vae_name = args.vae
|
vae_name = args.vae
|
||||||
if os.path.exists(vae_name):
|
if os.path.exists(vae_name):
|
||||||
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
||||||
|
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
|
||||||
vae_name = os.path.basename(vae_name)
|
vae_name = os.path.basename(vae_name)
|
||||||
metadata["ss_vae_name"] = vae_name
|
metadata["ss_vae_name"] = vae_name
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user