fix to work with num_workers=0

This commit is contained in:
Kohya S
2023-03-28 19:42:47 +09:00
parent 99eaf1fd65
commit 4f70e5dca6
5 changed files with 27 additions and 13 deletions

View File

@@ -3057,12 +3057,20 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# collate_fn用 epoch,stepはmultiprocessing.Value
class collater_class:
def __init__(self, epoch, step):
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
def __call__(self, examples):
dataset = torch.utils.data.get_worker_info().dataset
worker_info = torch.utils.data.get_worker_info()
# worker_info is None in the main process
if worker_info is not None:
dataset = worker_info.dataset
else:
dataset = self.dataset
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]