diff --git a/library/train_util.py b/library/train_util.py index 4ec26770..24e15d1f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -432,17 +432,25 @@ class BaseDataset(torch.utils.data.Dataset): # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる self.buckets_indices: List(BucketBatchIndex) = [] for bucket_index, bucket in enumerate(self.bucket_manager.buckets): - # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # そのためバッチサイズを画像種類までに制限する - # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # TODO 正則化画像をepochまたがりで利用する仕組み - num_of_image_types = len(set(bucket)) - bucket_batch_size = min(self.batch_size, num_of_image_types) - batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): - self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) + self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) + + # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す + #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる + # + # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは + # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう + # # そのためバッチサイズを画像種類までに制限する + # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? + # # TO DO 正則化画像をepochまたがりで利用する仕組み + # num_of_image_types = len(set(bucket)) + # bucket_batch_size = min(self.batch_size, num_of_image_types) + # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) + # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # for batch_index in range(batch_count): + # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) + # ↑ここまで self.shuffle_buckets() self._length = len(self.buckets_indices)