def model_hash(filename):
    """Return a short hash identifying a model checkpoint file.

    Only a fixed window of the file is hashed: the 0x10000 bytes starting at
    byte offset 0x100000 are fed to SHA-256 and the first 8 hex digits of the
    digest are returned. A file shorter than the offset therefore hashes as
    empty input.

    Args:
        filename: Path to the checkpoint to hash.

    Returns:
        The 8-character hex digest prefix, ``'NOFILE'`` when the path does
        not exist, or ``'IsADirectory'`` when the path is a directory.

    Raises:
        PermissionError: If the path is a regular file that cannot be read.
    """
    import hashlib
    import os

    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'
    except (IsADirectoryError, PermissionError):
        # Opening a directory raises IsADirectoryError on POSIX and
        # PermissionError on Windows. Report directories with a sentinel
        # instead of crashing, but let real permission problems propagate.
        if os.path.isdir(filename):
            return 'IsADirectory'
        raise
LoRANetwork(torch.nn.Module): def get_trainable_params(self): return self.parameters() - def save_weights(self, file, dtype): + def save_weights(self, file, dtype, metadata): state_dict = self.state_dict() if dtype is not None: @@ -185,6 +188,8 @@ class LoRANetwork(torch.nn.Module): if os.path.splitext(file)[1] == '.safetensors': from safetensors.torch import save_file - save_file(state_dict, file) + save_file(state_dict, file, metadata) else: torch.save(state_dict, file) + with zipfile.ZipFile(file, "w") as zipf: + zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) diff --git a/train_network.py b/train_network.py index 9f292b97..c5593c46 100644 --- a/train_network.py +++ b/train_network.py @@ -194,9 +194,50 @@ def train(args): print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + metadata = { + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset.num_train_images, + "ss_num_reg_images": train_dataset.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": 4 if args.network_dim is None else args.network_dim, + 
"ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_resolution": args.resolution, + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_cache_latents": bool(args.cache_latents), + "ss_enable_bucket": bool(args.enable_bucket), + "ss_min_bucket_reso": args.min_bucket_reso, + "ss_max_bucket_reso": args.max_bucket_reso, + "ss_seed": args.seed + } + + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + metadata = {k: str(v) for k, v in metadata.items()} + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -208,6 +249,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + metadata["ss_epoch"] = str(epoch+1) network.on_epoch_start(text_encoder, unet) @@ -296,7 +338,7 @@ def train(args): ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"saving checkpoint: {ckpt_file}") - unwrap_model(network).save_weights(ckpt_file, save_dtype) + unwrap_model(network).save_weights(ckpt_file, save_dtype, metadata) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' 
+ args.save_model_as @@ -311,6 +353,8 @@ def train(args): # end of epoch + metadata["ss_epoch"] = str(num_train_epochs) + is_main_process = accelerator.is_main_process if is_main_process: network = unwrap_model(network) @@ -330,7 +374,7 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype) + network.save_weights(ckpt_file, save_dtype, metadata) print("model saved.")