mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #103 from space-nuko/bucketing-metadata
Add bucketing metadata
This commit is contained in:
@@ -87,6 +87,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.enable_bucket = False
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_info = None
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
@@ -219,9 +220,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.buckets[bucket_index].append(image_info.image_key)
|
||||
|
||||
if self.enable_bucket:
|
||||
self.bucket_info = {"buckets": {}}
|
||||
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
||||
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
|
||||
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
|
||||
self.bucket_info["img_ar_errors"] = img_ar_errors
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user