diff --git a/library/train_util.py b/library/train_util.py index aa65dc3c..94175b98 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -11,6 +11,7 @@ import glob import math import os import random +import hashlib from tqdm import tqdm import torch @@ -79,6 +80,8 @@ class BaseDataset(torch.utils.data.Dataset): self.debug_dataset = debug_dataset self.random_crop = random_crop self.token_padding_disabled = False + self.dataset_dirs = {} + self.reg_dataset_dirs = {} self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -523,6 +526,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, False, img_path) self.register_image(info) + self.dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images @@ -539,6 +543,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, True, img_path) reg_infos.append(info) + self.reg_dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: @@ -749,9 +754,9 @@ def default(val, d): def model_hash(filename): + """Old model hash used by stable-diffusion-webui""" try: with open(filename, "rb") as file: - import hashlib m = hashlib.sha256() file.seek(0x100000) @@ -761,6 +766,18 @@ def model_hash(filename): return 'NOFILE' +def calculate_sha256(filename): + """New model hash used by stable-diffusion-webui""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + # flash attention forwards and backwards # https://arxiv.org/abs/2205.14135 diff --git a/train_network.py b/train_network.py index b2c7b579..73370ee2 100644 --- a/train_network.py +++ b/train_network.py @@ -3,6 +3,9 @@ import argparse import gc import math import os +import random +import time +import json from tqdm import tqdm import torch @@ -19,6 +22,8 @@ def collate_fn(examples): def train(args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -206,10 +211,13 @@ def train(args): print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, "ss_text_encoder_lr": args.text_encoder_lr, "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data + "ss_num_train_images": train_dataset.num_train_images, # includes repeating "ss_num_reg_images": train_dataset.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, @@ -235,7 +243,10 @@ def train(args): "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset "ss_max_bucket_reso": args.max_bucket_reso, - "ss_seed": args.seed + "ss_seed": args.seed, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs), + "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs), } # uncomment if another network is added @@ -246,6 +257,7 @@ def train(args): sd_model_name = args.pretrained_model_name_or_path if os.path.exists(sd_model_name): metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) sd_model_name = os.path.basename(sd_model_name) metadata["ss_sd_model_name"] = sd_model_name @@ -253,6 +265,7 @@ def train(args): vae_name = args.vae if os.path.exists(vae_name): metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) vae_name = os.path.basename(vae_name) metadata["ss_vae_name"] = vae_name