Add bucketing metadata

This commit is contained in:
space-nuko
2023-01-23 17:26:58 -08:00
parent 93df55d597
commit 66051883fb
2 changed files with 5 additions and 0 deletions

View File

@@ -85,6 +85,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.enable_bucket = False
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_info = None
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -217,9 +218,12 @@ class BaseDataset(torch.utils.data.Dataset):
self.buckets[bucket_index].append(image_info.image_key)
if self.enable_bucket:
self.bucket_info = {"buckets": {}}
print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む")
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
self.bucket_info["img_ar_errors"] = img_ar_errors
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")