From 91a50ea63734b548ae593a474c5248aa8307f5c0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 24 Jan 2023 20:17:15 +0900 Subject: [PATCH] Change img_ar_errors to mean because too many imgs --- library/train_util.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 55a9aacd..f967c5f8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -225,9 +225,12 @@ class BaseDataset(torch.utils.data.Dataset): for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)} print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") - self.bucket_info["img_ar_errors"] = img_ar_errors + img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + self.bucket_info["mean_img_ar_error"] = mean_img_ar_error + print(f"mean ar error (without repeats): {mean_img_ar_error}") + # 参照用indexを作る self.buckets_indices: list(BucketBatchIndex) = [] @@ -834,7 +837,7 @@ def addnet_hash_safetensors(b): offset = n + 8 b.seek(offset) for chunk in iter(lambda: b.read(blksize), b""): - hash_sha256.update(chunk) + hash_sha256.update(chunk) return hash_sha256.hexdigest() @@ -1107,7 +1110,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_n_epoch_ratio", type=int, default=None, - help="save checkpoint N epoch ratio / 学習中のモデルを指定のエポック割合で保存する") + help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")