diff --git a/fine_tune.py b/fine_tune.py index 02f665bd..8e615203 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -200,6 +200,8 @@ def train(args): # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps diff --git a/library/train_util.py b/library/train_util.py index 0fdbadc1..f967c5f8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,6 +12,7 @@ import math import os import random import hashlib +from io import BytesIO from tqdm import tqdm import torch @@ -25,6 +26,7 @@ from PIL import Image import cv2 from einops import rearrange from torch import einsum +import safetensors.torch import library.model_util as model_util @@ -85,6 +87,7 @@ class BaseDataset(torch.utils.data.Dataset): self.enable_bucket = False self.min_bucket_reso = None self.max_bucket_reso = None + self.bucket_info = None self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -217,11 +220,17 @@ class BaseDataset(torch.utils.data.Dataset): self.buckets[bucket_index].append(image_info.image_key) if self.enable_bucket: + self.bucket_info = {"buckets": {}} print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)} print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") + img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + self.bucket_info["mean_img_ar_error"] = mean_img_ar_error + print(f"mean ar error (without repeats): {mean_img_ar_error}") + # 参照用indexを作る self.buckets_indices: list(BucketBatchIndex) = [] @@ -790,6 +799,49 @@ def calculate_sha256(filename): return hash_sha256.hexdigest() +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + # flash attention forwards and backwards # https://arxiv.org/abs/2205.14135 @@ -1057,6 +1109,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_n_epoch_ratio", type=int, default=None, + help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") diff --git a/networks/lora.py b/networks/lora.py index 9243f1e1..174feda5 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -7,6 +7,8 @@ import math import os import torch +from library import train_util + class LoRAModule(torch.nn.Module): """ @@ -31,7 +33,7 @@ class LoRAModule(torch.nn.Module): self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: - alpha = alpha.detach().numpy() + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える @@ -221,6 +223,14 @@ class LoRANetwork(torch.nn.Module): if os.path.splitext(file)[1] == '.safetensors': from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + save_file(state_dict, file, metadata) else: torch.save(state_dict, file) diff --git a/train_db.py b/train_db.py index 8ac503ea..fe6fd4e6 100644 --- a/train_db.py +++ b/train_db.py @@ -176,6 +176,8 @@ def train(args): # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps diff --git a/train_network.py b/train_network.py index d60ae9a0..8a8acc7d 100644 --- a/train_network.py +++ b/train_network.py @@ -212,6 +212,8 @@ def train(args): # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -264,6 +266,7 @@ def train(args): "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), + "ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_training_comment": args.training_comment # will not be updated after training } @@ -437,8 +440,8 @@ if __name__ == '__main__': train_util.add_training_arguments(parser, True) parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") - parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)") parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")