From de1dde1a06b5a591a376383d441a1710a743f1b1 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Tue, 17 Jan 2023 16:28:35 -0800
Subject: [PATCH] More helpful metadata

- dataset/reg image dirs
- random session ID
- keep_tokens
- training date
- output name
---
 library/train_util.py |  4 ++++
 train_network.py      | 15 +++++++++++++--
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 57ebf1b0..e3ff1a38 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -79,6 +79,8 @@ class BaseDataset(torch.utils.data.Dataset):
     self.debug_dataset = debug_dataset
     self.random_crop = random_crop
     self.token_padding_disabled = False
+    self.dataset_dirs = {}
+    self.reg_dataset_dirs = {}
 
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
 
@@ -523,6 +525,7 @@ class DreamBoothDataset(BaseDataset):
       for img_path, caption in zip(img_paths, captions):
         info = ImageInfo(img_path, n_repeats, caption, False, img_path)
         self.register_image(info)
+      self.dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
     print(f"{num_train_images} train images with repeating.")
     self.num_train_images = num_train_images
@@ -539,6 +542,7 @@ class DreamBoothDataset(BaseDataset):
         for img_path, caption in zip(img_paths, captions):
           info = ImageInfo(img_path, n_repeats, caption, True, img_path)
           reg_infos.append(info)
+        self.reg_dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
 
       print(f"{num_reg_images} reg images.")
       if num_train_images < num_reg_images:
diff --git a/train_network.py b/train_network.py
index c0a881ad..098145f2 100644
--- a/train_network.py
+++ b/train_network.py
@@ -3,6 +3,9 @@ import argparse
 import gc
 import math
 import os
+import random
+import time
+import json
 
 from tqdm import tqdm
 import torch
@@ -19,6 +22,8 @@ def collate_fn(examples):
 
 
 def train(args):
+  session_id = random.randint(0, 2**32)
+  training_started_at = time.time()
   train_util.verify_training_args(args)
   train_util.prepare_dataset_args(args, True)
 
@@ -203,10 +208,13 @@ def train(args):
   print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
   metadata = {
+      "ss_session_id": session_id,  # random integer indicating which group of epochs the model came from
+      "ss_training_started_at": training_started_at,  # unix timestamp
+      "ss_output_name": args.output_name,
       "ss_learning_rate": args.learning_rate,
       "ss_text_encoder_lr": args.text_encoder_lr,
       "ss_unet_lr": args.unet_lr,
-      "ss_num_train_images": train_dataset.num_train_images,  # includes repeating TODO more detailed data
+      "ss_num_train_images": train_dataset.num_train_images,  # includes repeating
       "ss_num_reg_images": train_dataset.num_reg_images,
       "ss_num_batches_per_epoch": len(train_dataloader),
       "ss_num_epochs": num_train_epochs,
@@ -232,7 +240,10 @@ def train(args):
       "ss_enable_bucket": bool(train_dataset.enable_bucket),  # TODO move to BaseDataset from DB/FT
       "ss_min_bucket_reso": args.min_bucket_reso,  # TODO get from dataset
       "ss_max_bucket_reso": args.max_bucket_reso,
-      "ss_seed": args.seed
+      "ss_seed": args.seed,
+      "ss_keep_tokens": args.keep_tokens,
+      "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs),
+      "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs),
   }
 
   # uncomment if another network is added
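
Note (not part of the patch): a minimal sketch of how the new fields could be read back after training, assuming the network was saved in .safetensors format and the safetensors Python package is available; the file name used here is hypothetical. safetensors stores all metadata values as strings, and ss_dataset_dirs / ss_reg_dataset_dirs are additionally JSON-encoded by this patch, so they need json.loads() to recover the per-directory repeat and image counts.

    import json
    from safetensors import safe_open

    # Open the trained network file and pull its header metadata.
    with safe_open("my_lora.safetensors", framework="pt") as f:  # hypothetical file name
        metadata = f.metadata() or {}

    print(metadata.get("ss_session_id"))   # random per-run session ID
    print(metadata.get("ss_output_name"))  # output name used for the run

    # The dataset directory summaries are JSON strings, not nested dicts.
    dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", "{}"))
    for dir_name, info in dataset_dirs.items():
        print(dir_name, info["n_repeats"], info["img_count"])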