mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix to work with num_workers=0
This commit is contained in:
@@ -3057,12 +3057,20 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
|
||||
|
||||
# collate_fn用 epoch,stepはmultiprocessing.Value
|
||||
class collater_class:
|
||||
def __init__(self, epoch, step):
|
||||
def __init__(self, epoch, step, dataset):
|
||||
self.current_epoch = epoch
|
||||
self.current_step = step
|
||||
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
|
||||
|
||||
def __call__(self, examples):
|
||||
dataset = torch.utils.data.get_worker_info().dataset
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
# worker_info is None in the main process
|
||||
if worker_info is not None:
|
||||
dataset = worker_info.dataset
|
||||
else:
|
||||
dataset = self.dataset
|
||||
|
||||
# set epoch and step
|
||||
dataset.set_current_epoch(self.current_epoch.value)
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
|
||||
Reference in New Issue
Block a user