Merge pull request #77 from space-nuko/ss-extra-metadata
More helpful metadata
train_util.py
@@ -11,6 +11,7 @@ import glob
 import math
 import os
 import random
+import hashlib
 
 from tqdm import tqdm
 import torch
@@ -79,6 +80,8 @@ class BaseDataset(torch.utils.data.Dataset):
     self.debug_dataset = debug_dataset
     self.random_crop = random_crop
     self.token_padding_disabled = False
+    self.dataset_dirs = {}
+    self.reg_dataset_dirs = {}
 
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
 
@@ -523,6 +526,7 @@ class DreamBoothDataset(BaseDataset):
       for img_path, caption in zip(img_paths, captions):
         info = ImageInfo(img_path, n_repeats, caption, False, img_path)
         self.register_image(info)
+      self.dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
     print(f"{num_train_images} train images with repeating.")
     self.num_train_images = num_train_images
 
@@ -539,6 +543,7 @@ class DreamBoothDataset(BaseDataset):
       for img_path, caption in zip(img_paths, captions):
         info = ImageInfo(img_path, n_repeats, caption, True, img_path)
         reg_infos.append(info)
+      self.reg_dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
     print(f"{num_reg_images} reg images.")
     if num_train_images < num_reg_images:
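For orientation (an aside, not part of the diff): dataset_dirs and reg_dataset_dirs record one entry per DreamBooth folder. Assuming the usual "<n_repeats>_<name>" folder naming and a hypothetical folder "5_myconcept" holding 20 images, the new bookkeeping would look roughly like the sketch below.

import json

# Hypothetical example: a DreamBooth train folder named "5_myconcept"
# (following the "<n_repeats>_<name>" convention) containing 20 images.
dataset_dirs = {"5_myconcept": {"n_repeats": 5, "img_count": 20}}

# train_network.py later serializes this dict into the ss_dataset_dirs metadata key.
print(json.dumps(dataset_dirs))   # {"5_myconcept": {"n_repeats": 5, "img_count": 20}}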
@@ -749,9 +754,9 @@ def default(val, d):
 
 
 def model_hash(filename):
+  """Old model hash used by stable-diffusion-webui"""
   try:
     with open(filename, "rb") as file:
-      import hashlib
       m = hashlib.sha256()
 
       file.seek(0x100000)
@@ -761,6 +766,18 @@ def model_hash(filename):
     return 'NOFILE'
 
 
+def calculate_sha256(filename):
+  """New model hash used by stable-diffusion-webui"""
+  hash_sha256 = hashlib.sha256()
+  blksize = 1024 * 1024
+
+  with open(filename, "rb") as f:
+    for chunk in iter(lambda: f.read(blksize), b""):
+      hash_sha256.update(chunk)
+
+  return hash_sha256.hexdigest()
+
+
 # flash attention forwards and backwards
 
 # https://arxiv.org/abs/2205.14135
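An aside on the two hash helpers (not part of the diff): model_hash reproduces stable-diffusion-webui's older hash, which seeks to offset 0x100000 and digests only a small slice of the checkpoint, while calculate_sha256 reads the whole file in 1 MiB chunks and returns the full SHA-256 digest. A minimal usage sketch, with a hypothetical checkpoint path:

# Hypothetical checkpoint path; both helpers are defined above in train_util.py.
ckpt = "/path/to/model.ckpt"
legacy_hash = model_hash(ckpt)        # webui-style hash computed from a slice of the file
full_hash = calculate_sha256(ckpt)    # full SHA-256 of the entire file
print(legacy_hash, full_hash)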
train_network.py
@@ -3,6 +3,9 @@ import argparse
 import gc
 import math
 import os
+import random
+import time
+import json
 
 from tqdm import tqdm
 import torch
@@ -19,6 +22,8 @@ def collate_fn(examples):
 
 
 def train(args):
+  session_id = random.randint(0, 2**32)
+  training_started_at = time.time()
   train_util.verify_training_args(args)
   train_util.prepare_dataset_args(args, True)
 
@@ -206,10 +211,13 @@ def train(args):
   print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
   metadata = {
+      "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
+      "ss_training_started_at": training_started_at, # unix timestamp
+      "ss_output_name": args.output_name,
       "ss_learning_rate": args.learning_rate,
       "ss_text_encoder_lr": args.text_encoder_lr,
       "ss_unet_lr": args.unet_lr,
-      "ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
+      "ss_num_train_images": train_dataset.num_train_images, # includes repeating
       "ss_num_reg_images": train_dataset.num_reg_images,
       "ss_num_batches_per_epoch": len(train_dataloader),
       "ss_num_epochs": num_train_epochs,
@@ -235,7 +243,10 @@ def train(args):
       "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
       "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
       "ss_max_bucket_reso": args.max_bucket_reso,
-      "ss_seed": args.seed
+      "ss_seed": args.seed,
+      "ss_keep_tokens": args.keep_tokens,
+      "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs),
+      "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs),
   }
 
   # uncomment if another network is added
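Not shown in the diff: assuming the trained network is saved in .safetensors format, which stores metadata in the file header, these ss_* keys can be read back later. A minimal readback sketch using the safetensors library (file name hypothetical):

from safetensors import safe_open

# Hypothetical output file from train_network.py
with safe_open("my_lora.safetensors", framework="pt", device="cpu") as f:
    meta = f.metadata() or {}     # header metadata is a dict of strings, or None

print(meta.get("ss_session_id"), meta.get("ss_dataset_dirs"))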
@@ -246,6 +257,7 @@ def train(args):
   sd_model_name = args.pretrained_model_name_or_path
   if os.path.exists(sd_model_name):
     metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
+    metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
     sd_model_name = os.path.basename(sd_model_name)
   metadata["ss_sd_model_name"] = sd_model_name
 
@@ -253,6 +265,7 @@ def train(args):
   vae_name = args.vae
   if os.path.exists(vae_name):
     metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
+    metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
     vae_name = os.path.basename(vae_name)
   metadata["ss_vae_name"] = vae_name
 