Merge pull request #77 from space-nuko/ss-extra-metadata

More helpful metadata
Kohya S, 2023-01-21 12:18:23 +09:00 (committed by GitHub)
2 changed files with 33 additions and 3 deletions

library/train_util.py

@@ -11,6 +11,7 @@ import glob
 import math
 import os
 import random
+import hashlib
 from tqdm import tqdm
 import torch
@@ -79,6 +80,8 @@ class BaseDataset(torch.utils.data.Dataset):
     self.debug_dataset = debug_dataset
     self.random_crop = random_crop
     self.token_padding_disabled = False
+    self.dataset_dirs = {}
+    self.reg_dataset_dirs = {}
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -523,6 +526,7 @@ class DreamBoothDataset(BaseDataset):
       for img_path, caption in zip(img_paths, captions):
         info = ImageInfo(img_path, n_repeats, caption, False, img_path)
         self.register_image(info)
+      self.dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}

     print(f"{num_train_images} train images with repeating.")
     self.num_train_images = num_train_images
@@ -539,6 +543,7 @@ class DreamBoothDataset(BaseDataset):
       for img_path, caption in zip(img_paths, captions):
         info = ImageInfo(img_path, n_repeats, caption, True, img_path)
         reg_infos.append(info)
+      self.reg_dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}

     print(f"{num_reg_images} reg images.")
     if num_train_images < num_reg_images:
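Note: these two dicts record per-directory stats for the train and regularization sets. In this repo's DreamBooth layout the repeat count comes from the `N_identifier` directory-name prefix, so the recorded values look roughly like the sketch below (directory names and counts are hypothetical):

```python
# Illustrative only: given a DreamBooth tree such as
#   train_data/5_myconcept/  (20 images, each repeated 5x per epoch)
#   reg_data/1_person/       (200 regularization images)
# the new bookkeeping dicts end up shaped like:
dataset_dirs = {"5_myconcept": {"n_repeats": 5, "img_count": 20}}
reg_dataset_dirs = {"1_person": {"n_repeats": 1, "img_count": 200}}
```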
@@ -749,9 +754,9 @@ def default(val, d):
 def model_hash(filename):
+  """Old model hash used by stable-diffusion-webui"""
   try:
     with open(filename, "rb") as file:
-      import hashlib
       m = hashlib.sha256()
       file.seek(0x100000)
@@ -761,6 +766,18 @@ def model_hash(filename):
     return 'NOFILE'
+
+def calculate_sha256(filename):
+  """New model hash used by stable-diffusion-webui"""
+  hash_sha256 = hashlib.sha256()
+  blksize = 1024 * 1024
+
+  with open(filename, "rb") as f:
+    for chunk in iter(lambda: f.read(blksize), b""):
+      hash_sha256.update(chunk)
+
+  return hash_sha256.hexdigest()
+
 # flash attention forwards and backwards
 # https://arxiv.org/abs/2205.14135
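For context: `model_hash` replicates stable-diffusion-webui's legacy hash, which fingerprints only a 64 KiB window starting at offset 0x100000 and keeps 8 hex characters, so distinct checkpoints can collide; `calculate_sha256` hashes the entire file. A minimal sketch contrasting the two (the untruncated `model_hash` body is filled in from the webui implementation the hunk above mirrors, so treat it as illustrative):

```python
import hashlib

def model_hash(filename):
    """Old webui hash: 8 hex chars over a 64 KiB window at offset 1 MiB."""
    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()
            file.seek(0x100000)           # skip the first 1 MiB
            m.update(file.read(0x10000))  # hash only the next 64 KiB
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'

def calculate_sha256(filename):
    """New webui hash: SHA-256 of the whole file, read in 1 MiB chunks."""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(blksize), b""):
            hash_sha256.update(chunk)
    return hash_sha256.hexdigest()

# model_hash("model.ckpt") returns only 8 hex chars, so unrelated checkpoints
# can collide; calculate_sha256 returns the full 64-char digest.
```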

train_network.py

@@ -3,6 +3,9 @@ import argparse
 import gc
 import math
 import os
+import random
+import time
+import json
 from tqdm import tqdm
 import torch
@@ -19,6 +22,8 @@ def collate_fn(examples):
 def train(args):
+  session_id = random.randint(0, 2**32)
+  training_started_at = time.time()
   train_util.verify_training_args(args)
   train_util.prepare_dataset_args(args, True)
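Because `session_id` and `training_started_at` are generated once per invocation of `train()`, every per-epoch checkpoint saved by a run carries the same pair. A hedged sketch of how a downstream tool might exploit that (the helper is hypothetical; safetensors metadata values come back as strings):

```python
from collections import defaultdict

def group_by_session(metadatas):
    """Group checkpoint metadata dicts by the run that produced them."""
    runs = defaultdict(list)
    for meta in metadatas:
        runs[meta.get("ss_session_id")].append(meta)
    return runs
```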
@@ -206,10 +211,13 @@ def train(args):
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = { metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate, "ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr, "ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr, "ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data "ss_num_train_images": train_dataset.num_train_images, # includes repeating
"ss_num_reg_images": train_dataset.num_reg_images, "ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader), "ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs, "ss_num_epochs": num_train_epochs,
@@ -235,7 +243,10 @@ def train(args):
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
"ss_max_bucket_reso": args.max_bucket_reso, "ss_max_bucket_reso": args.max_bucket_reso,
"ss_seed": args.seed "ss_seed": args.seed,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs),
} }
 # uncomment if another network is added
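The nested per-directory dicts go through `json.dumps` because safetensors metadata is a flat string-to-string mapping; consumers reverse it with `json.loads`. A small round-trip sketch (directory name is hypothetical):

```python
import json

# what actually lands in the metadata header (a single string value):
ss_dataset_dirs = json.dumps({"5_myconcept": {"n_repeats": 5, "img_count": 20}})

# and how a reader recovers the structure:
restored = json.loads(ss_dataset_dirs)
assert restored["5_myconcept"]["img_count"] == 20
```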
@@ -246,6 +257,7 @@ def train(args):
   sd_model_name = args.pretrained_model_name_or_path
   if os.path.exists(sd_model_name):
     metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
+    metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
     sd_model_name = os.path.basename(sd_model_name)

   metadata["ss_sd_model_name"] = sd_model_name
@@ -253,6 +265,7 @@ def train(args):
   vae_name = args.vae
   if os.path.exists(vae_name):
     metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
+    metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
     vae_name = os.path.basename(vae_name)

   metadata["ss_vae_name"] = vae_name