Merge pull request #77 from space-nuko/ss-extra-metadata

More helpful metadata
This commit is contained in:
Kohya S
2023-01-21 12:18:23 +09:00
committed by GitHub
2 changed files with 33 additions and 3 deletions

View File

@@ -11,6 +11,7 @@ import glob
import math
import os
import random
import hashlib
from tqdm import tqdm
import torch
@@ -79,6 +80,8 @@ class BaseDataset(torch.utils.data.Dataset):
self.debug_dataset = debug_dataset
self.random_crop = random_crop
self.token_padding_disabled = False
self.dataset_dirs = {}
self.reg_dataset_dirs = {}
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -523,6 +526,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info)
self.dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
@@ -539,6 +543,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info)
self.reg_dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
@@ -749,9 +754,9 @@ def default(val, d):
def model_hash(filename):
"""Old model hash used by stable-diffusion-webui"""
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
@@ -761,6 +766,18 @@ def model_hash(filename):
return 'NOFILE'
def calculate_sha256(filename):
"""New model hash used by stable-diffusion-webui"""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135