Merge pull request #77 from space-nuko/ss-extra-metadata

More helpful metadata
This commit is contained in:
Kohya S
2023-01-21 12:18:23 +09:00
committed by GitHub
2 changed files with 33 additions and 3 deletions

View File

@@ -3,6 +3,9 @@ import argparse
import gc
import math
import os
import random
import time
import json
from tqdm import tqdm
import torch
@@ -19,6 +22,8 @@ def collate_fn(examples):
def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
@@ -206,10 +211,13 @@ def train(args):
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
"ss_num_train_images": train_dataset.num_train_images, # includes repeating
"ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
@@ -235,7 +243,10 @@ def train(args):
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
"ss_max_bucket_reso": args.max_bucket_reso,
"ss_seed": args.seed
"ss_seed": args.seed,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs),
}
# uncomment if another network is added
@@ -246,6 +257,7 @@ def train(args):
sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name
@@ -253,6 +265,7 @@ def train(args):
vae_name = args.vae
if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name