Merge pull request #62 from kohya-ss/dev

Add training metadata to saved models. Thanks to space-nuko!
This commit is contained in:
Kohya S
2023-01-12 21:55:50 +09:00
committed by GitHub
5 changed files with 91 additions and 7 deletions

View File

@@ -46,11 +46,13 @@ VGG(
)
"""
import json
from typing import List, Optional, Union
import glob
import importlib
import inspect
import time
import zipfile
from diffusers.utils import deprecate
from diffusers.configuration_utils import FrozenDict
import argparse
@@ -1972,6 +1974,14 @@ def main(args):
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if os.path.splitext(network_weight)[1] == '.safetensors':
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network.load_weights(network_weight)
network.apply_to(text_encoder, unet)

View File

@@ -747,6 +747,20 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
def model_hash(filename):
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
return m.hexdigest()[0:8]
except FileNotFoundError:
return 'NOFILE'
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135

View File

@@ -135,7 +135,7 @@ def svd(args):
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
lora_network_o.save_weights(args.save_to, save_dtype)
lora_network_o.save_weights(args.save_to, save_dtype, {})
print(f"LoRA weights are saved to: {args.save_to}")

View File

@@ -92,7 +92,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file
from safetensors.torch import load_file, safe_open
self.weights_sd = load_file(file)
else:
self.weights_sd = torch.load(file, map_location='cpu')
@@ -174,7 +174,10 @@ class LoRANetwork(torch.nn.Module):
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype):
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
@@ -185,6 +188,6 @@ class LoRANetwork(torch.nn.Module):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file
save_file(state_dict, file)
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)

View File

@@ -194,9 +194,62 @@ def train(args):
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = {
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
"ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
"ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
"ss_resolution": args.resolution,
"ss_clip_skip": args.clip_skip,
"ss_max_token_length": args.max_token_length,
"ss_color_aug": bool(args.color_aug),
"ss_flip_aug": bool(args.flip_aug),
"ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption),
"ss_cache_latents": bool(args.cache_latents),
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
"ss_max_bucket_reso": args.max_bucket_reso,
"ss_seed": args.seed
}
# uncomment if another network is added
# for key, value in net_kwargs.items():
# metadata["ss_arg_" + key] = value
if args.pretrained_model_name_or_path is not None:
sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name
if args.vae is not None:
vae_name = args.vae
if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name
metadata = {k: str(v) for k, v in metadata.items()}
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -208,6 +261,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
metadata["ss_epoch"] = str(epoch+1)
network.on_epoch_start(text_encoder, unet)
@@ -296,7 +350,7 @@ def train(args):
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype)
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -311,6 +365,8 @@ def train(args):
# end of epoch
metadata["ss_epoch"] = str(num_train_epochs)
is_main_process = accelerator.is_main_process
if is_main_process:
network = unwrap_model(network)
@@ -330,7 +386,7 @@ def train(args):
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype)
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
print("model saved.")
@@ -341,6 +397,7 @@ if __name__ == '__main__':
train_util.add_dataset_arguments(parser, True, True)
train_util.add_training_arguments(parser, True)
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")