Final implementation

This commit is contained in:
Kohaku-Blueleaf
2024-05-31 12:20:20 +08:00
parent 0d96e10b3e
commit b2363f1021
2 changed files with 106 additions and 9 deletions

View File

@@ -657,8 +657,15 @@ class BaseDataset(torch.utils.data.Dataset):
def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch
if epoch > self.current_epoch:
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
num_epochs = epoch - self.current_epoch
for _ in range(num_epochs):
self.current_epoch += 1
self.shuffle_buckets()
else:
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
self.current_epoch = epoch
def set_current_step(self, step):
self.current_step = step