lora以外も対応

This commit is contained in:
u-haru
2023-03-26 02:19:55 +09:00
parent 9c80da6ac5
commit 4dc1124f93
5 changed files with 37 additions and 33 deletions

View File

@@ -2987,3 +2987,14 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion
# colalte_fn用 epoch,stepはmultiprocessing.Value
class collater_class:
def __init__(self,epoch,step):
self.current_epoch=epoch
self.current_step=step
def __call__(self, examples):
dataset = torch.utils.data.get_worker_info().dataset
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]