From 5e817e4343cdf839096b430877609c6f52749a30 Mon Sep 17 00:00:00 2001
From: forestsource
Date: Sun, 22 Jan 2023 02:57:12 +0900
Subject: [PATCH 1/6] Add save_n_epoch_ratio

---
 fine_tune.py          | 2 ++
 library/train_util.py | 2 ++
 train_db.py           | 2 ++
 train_network.py      | 2 ++
 4 files changed, 8 insertions(+)

diff --git a/fine_tune.py b/fine_tune.py
index 02f665bd..8e615203 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -200,6 +200,8 @@ def train(args):
   # epoch数を計算する
   num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
   num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+  if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
+    args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
 
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
diff --git a/library/train_util.py b/library/train_util.py
index aa65dc3c..5ff0280e 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1028,6 +1028,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
                       choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
   parser.add_argument("--save_every_n_epochs", type=int, default=None,
                       help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
+  parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
+                      help="save checkpoint N epoch ratio / 学習中のモデルを指定のエポック割合で保存する")
   parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
   parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
                       help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
   parser.add_argument("--save_state", action="store_true",
diff --git a/train_db.py b/train_db.py
index 8ac503ea..fe6fd4e6 100644
--- a/train_db.py
+++ b/train_db.py
@@ -176,6 +176,8 @@ def train(args):
   # epoch数を計算する
   num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
   num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+  if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
+    args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
 
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
diff --git a/train_network.py b/train_network.py
index b2c7b579..d3282da9 100644
--- a/train_network.py
+++ b/train_network.py
@@ -192,6 +192,8 @@ def train(args):
   # epoch数を計算する
   num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
   num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+  if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
+    args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
 
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
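The ratio-to-interval arithmetic in patch 1/6 above is easiest to see with concrete numbers. A minimal standalone sketch of the same calculation (the values below are hypothetical; only the `math.floor(...)` expression and the `or 1` fallback come from the patch):

    import math

    num_train_epochs = 12      # derived from max_train_steps in the trainers
    save_n_epoch_ratio = 5     # user asks for roughly 5 checkpoints in total

    # floor(12 / 5) == 2, so a checkpoint is written every 2 epochs,
    # i.e. 6 checkpoints >= the requested 5. When the ratio exceeds the
    # epoch count, floor(...) is 0 (falsy) and the `or 1` fallback keeps
    # the save interval at a minimum of 1 epoch.
    save_every_n_epochs = math.floor(num_train_epochs / save_n_epoch_ratio) or 1
    print(save_every_n_epochs)  # 2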
From f7fbdc4b2aa52986cdab2e5482ba840457c6428f Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Mon, 23 Jan 2023 17:21:04 -0800
Subject: [PATCH 2/6] Precalculate .safetensors model hashes after training

---
 library/train_util.py | 45 +++++++++++++++++++++++++++++++++++++++++++
 networks/lora.py      | 10 ++++++++++
 2 files changed, 55 insertions(+)

diff --git a/library/train_util.py b/library/train_util.py
index 0fdbadc1..bbc68aae 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -12,6 +12,7 @@ import math
 import os
 import random
 import hashlib
+from io import BytesIO
 
 from tqdm import tqdm
 import torch
@@ -25,6 +26,7 @@ from PIL import Image
 import cv2
 from einops import rearrange
 from torch import einsum
+import safetensors.torch
 
 import library.model_util as model_util
 
@@ -790,6 +792,49 @@ def calculate_sha256(filename):
   return hash_sha256.hexdigest()
 
 
+def precalculate_safetensors_hashes(tensors, metadata):
+  """Precalculate the model hashes needed by sd-webui-additional-networks to
+  save time on indexing the model later."""
+
+  # Because writing user metadata to the file can change the result of
+  # sd_models.model_hash(), only retain the training metadata for purposes of
+  # calculating the hash, as they are meant to be immutable
+  metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+  bytes = safetensors.torch.save(tensors, metadata)
+  b = BytesIO(bytes)
+
+  model_hash = addnet_hash_safetensors(b)
+  legacy_hash = addnet_hash_legacy(b)
+  return model_hash, legacy_hash
+
+
+def addnet_hash_legacy(b):
+  """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+  m = hashlib.sha256()
+
+  b.seek(0x100000)
+  m.update(b.read(0x10000))
+  return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+  """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+  hash_sha256 = hashlib.sha256()
+  blksize = 1024 * 1024
+
+  b.seek(0)
+  header = b.read(8)
+  n = int.from_bytes(header, "little")
+
+  offset = n + 8
+  b.seek(offset)
+  for chunk in iter(lambda: b.read(blksize), b""):
+      hash_sha256.update(chunk)
+
+  return hash_sha256.hexdigest()
+
+
 # flash attention forwards and backwards
 # https://arxiv.org/abs/2205.14135
diff --git a/networks/lora.py b/networks/lora.py
index 9243f1e1..bbc65164 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -7,6 +7,8 @@ import math
 import os
 import torch
 
+from library import train_util
+
 
 class LoRAModule(torch.nn.Module):
   """
@@ -221,6 +223,14 @@ class LoRANetwork(torch.nn.Module):
 
     if os.path.splitext(file)[1] == '.safetensors':
       from safetensors.torch import save_file
+
+      # Precalculate model hashes to save time on indexing
+      if metadata is None:
+        metadata = {}
+      model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+      metadata["sshs_model_hash"] = model_hash
+      metadata["sshs_legacy_hash"] = legacy_hash
+
       save_file(state_dict, file, metadata)
     else:
       torch.save(state_dict, file)
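A rough usage sketch of the new helper from patch 2/6, outside the training scripts (the file name and metadata values below are hypothetical; `safetensors.torch.load_file` and `train_util.precalculate_safetensors_hashes` are the real entry points):

    import safetensors.torch
    from library import train_util

    # Load tensors from an existing LoRA file; the path is a placeholder.
    state_dict = safetensors.torch.load_file("my_lora.safetensors")

    # Only "ss_" (training) keys influence the hash; other keys are
    # filtered out inside precalculate_safetensors_hashes().
    metadata = {"ss_network_dim": "4", "user_note": "ignored for hashing"}

    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
    print(model_hash)   # sha256 hex digest over the tensor payload (header skipped)
    print(legacy_hash)  # 8-character digest in the style of the webui's legacy model hash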
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)} print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") + self.bucket_info["img_ar_errors"] = img_ar_errors img_ar_errors = np.array(img_ar_errors) print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") diff --git a/train_network.py b/train_network.py index d60ae9a0..5eada8f1 100644 --- a/train_network.py +++ b/train_network.py @@ -264,6 +264,7 @@ def train(args): "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), + "ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_training_comment": args.training_comment # will not be updated after training } From bf3a13bb4e4c4d45f2bedd3fbb752f33b3ec907b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 24 Jan 2023 18:57:21 +0900 Subject: [PATCH 4/6] Fix error for loading bf16 weights --- networks/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora.py b/networks/lora.py index 9243f1e1..b936bfb2 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -31,7 +31,7 @@ class LoRAModule(torch.nn.Module): self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: - alpha = alpha.detach().numpy() + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える From 9f644d8dc3cc435cf64fb3e2f7a169ac173410f2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 24 Jan 2023 20:16:21 +0900 Subject: [PATCH 5/6] Change default save format to safetensors --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 07aed4eb..8a8acc7d 100644 --- a/train_network.py +++ b/train_network.py @@ -440,8 +440,8 @@ if __name__ == '__main__': train_util.add_training_arguments(parser, True) parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") - parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)") parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") From 91a50ea63734b548ae593a474c5248aa8307f5c0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 24 Jan 2023 20:17:15 +0900 Subject: [PATCH 6/6] Change img_ar_errors to mean because too many imgs --- library/train_util.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 55a9aacd..f967c5f8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -225,9 +225,12 @@ class BaseDataset(torch.utils.data.Dataset): for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): 
From 91a50ea63734b548ae593a474c5248aa8307f5c0 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 24 Jan 2023 20:17:15 +0900
Subject: [PATCH 6/6] Change img_ar_errors to mean because too many imgs

---
 library/train_util.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 55a9aacd..f967c5f8 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -225,9 +225,12 @@ class BaseDataset(torch.utils.data.Dataset):
       for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
         self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
         print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
-      self.bucket_info["img_ar_errors"] = img_ar_errors
+
       img_ar_errors = np.array(img_ar_errors)
-      print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")
+      mean_img_ar_error = np.mean(np.abs(img_ar_errors))
+      self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
+      print(f"mean ar error (without repeats): {mean_img_ar_error}")
+
     # 参照用indexを作る
     self.buckets_indices: list(BucketBatchIndex) = []
@@ -834,7 +837,7 @@ def addnet_hash_safetensors(b):
   offset = n + 8
   b.seek(offset)
   for chunk in iter(lambda: b.read(blksize), b""):
-      hash_sha256.update(chunk)
+    hash_sha256.update(chunk)
 
   return hash_sha256.hexdigest()
@@ -1107,7 +1110,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
   parser.add_argument("--save_every_n_epochs", type=int, default=None,
                       help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
   parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
-                      help="save checkpoint N epoch ratio / 学習中のモデルを指定のエポック割合で保存する")
+                      help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
   parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
   parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
                       help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
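Taken together, patches 3/6 and 6/6 leave the `ss_bucket_info` metadata entry with a structure along these lines (a sketch with made-up numbers; the keys and shapes follow the diffs above):

    import json
    import numpy as np

    bucket_info = {"buckets": {}}
    bucket_info["buckets"][0] = {"resolution": (512, 512), "count": 40}
    bucket_info["buckets"][1] = {"resolution": (448, 576), "count": 12}

    # Per-image aspect-ratio errors are no longer stored individually;
    # only their mean absolute value is kept.
    img_ar_errors = np.array([0.012, -0.034, 0.008])
    bucket_info["mean_img_ar_error"] = float(np.mean(np.abs(img_ar_errors)))

    # Serialized into the model metadata as JSON, so resolution tuples
    # become lists and integer bucket indices become string keys:
    print(json.dumps(bucket_info))
    # -> {"buckets": {"0": {"resolution": [512, 512], "count": 40},
    #                 "1": {"resolution": [448, 576], "count": 12}},
    #     "mean_img_ar_error": ~0.018}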