diff --git a/library/train_util.py b/library/train_util.py index 0fdbadc1..e63ee828 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -85,6 +85,7 @@ class BaseDataset(torch.utils.data.Dataset): self.enable_bucket = False self.min_bucket_reso = None self.max_bucket_reso = None + self.bucket_info = None self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -217,9 +218,12 @@ class BaseDataset(torch.utils.data.Dataset): self.buckets[bucket_index].append(image_info.image_key) if self.enable_bucket: + self.bucket_info = {"buckets": {}} print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)} print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") + self.bucket_info["img_ar_errors"] = img_ar_errors img_ar_errors = np.array(img_ar_errors) print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") diff --git a/train_network.py b/train_network.py index d60ae9a0..5eada8f1 100644 --- a/train_network.py +++ b/train_network.py @@ -264,6 +264,7 @@ def train(args): "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), + "ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_training_comment": args.training_comment # will not be updated after training }