mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
revert batch size limiting for bucket
This commit is contained in:
@@ -432,17 +432,25 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List(BucketBatchIndex) = []
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
# bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# そのためバッチサイズを画像種類までに制限する
|
||||
# ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# TODO 正則化画像をepochまたがりで利用する仕組み
|
||||
num_of_image_types = len(set(bucket))
|
||||
bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
||||
|
||||
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
||||
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
||||
#
|
||||
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# # そのためバッチサイズを画像種類までに制限する
|
||||
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
||||
# num_of_image_types = len(set(bucket))
|
||||
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
# for batch_index in range(batch_count):
|
||||
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
# ↑ここまで
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
Reference in New Issue
Block a user