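Add more helpful training metadata

- dataset and regularization (reg) image directories, with per-directory repeat count and image count
- random session ID
- keep_tokens
- training start time (Unix timestamp)
- output name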
This commit is contained in:
space-nuko
2023-01-17 16:28:35 -08:00
parent f2f2ce0d7d
commit de1dde1a06
2 changed files with 17 additions and 2 deletions

View File

@@ -79,6 +79,8 @@ class BaseDataset(torch.utils.data.Dataset):
self.debug_dataset = debug_dataset self.debug_dataset = debug_dataset
self.random_crop = random_crop self.random_crop = random_crop
self.token_padding_disabled = False self.token_padding_disabled = False
self.dataset_dirs = {}
self.reg_dataset_dirs = {}
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -523,6 +525,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path) info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info) self.register_image(info)
self.dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.") print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images self.num_train_images = num_train_images
@@ -539,6 +542,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path) info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info) reg_infos.append(info)
self.reg_dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.") print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images: if num_train_images < num_reg_images:

View File

@@ -3,6 +3,9 @@ import argparse
import gc import gc
import math import math
import os import os
import random
import time
import json
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -19,6 +22,8 @@ def collate_fn(examples):
def train(args): def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True) train_util.prepare_dataset_args(args, True)
@@ -203,10 +208,13 @@ def train(args):
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = { metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate, "ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr, "ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr, "ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data "ss_num_train_images": train_dataset.num_train_images, # includes repeating
"ss_num_reg_images": train_dataset.num_reg_images, "ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader), "ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs, "ss_num_epochs": num_train_epochs,
@@ -232,7 +240,10 @@ def train(args):
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
"ss_max_bucket_reso": args.max_bucket_reso, "ss_max_bucket_reso": args.max_bucket_reso,
"ss_seed": args.seed "ss_seed": args.seed,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs),
} }
# uncomment if another network is added # uncomment if another network is added