support arbitrary dataset for train_network.py

This commit is contained in:
Kohya S
2023-06-14 12:49:12 +09:00
parent 8088c04a71
commit 9aee793078
2 changed files with 115 additions and 33 deletions

View File

@@ -1518,6 +1518,67 @@ def glob_images_pathlib(dir_path, recursive):
return image_paths
class MinimalDataset(BaseDataset):
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
self.num_train_images = 0 # update in subclass
self.num_reg_images = 0 # update in subclass
self.datasets = [self]
self.batch_size = 1 # update in subclass
self.subsets = [self]
self.num_repeats = 1 # update in subclass if needed
self.img_count = 1 # update in subclass if needed
self.bucket_info = {}
self.is_reg = False
self.image_dir = "dummy" # for metadata
def is_latent_cacheable(self) -> bool:
return False
def __len__(self):
raise NotImplementedError
# override to avoid shuffling buckets
def set_current_epoch(self, epoch):
self.current_epoch = epoch
def __getitem__(self, idx):
r"""
The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.
Returns: example like this:
for i in range(batch_size):
image_key = ... # whatever hashable
image_keys.append(image_key)
image = ... # PIL Image
img_tensor = self.image_transforms(img)
images.append(img_tensor)
caption = ... # str
input_ids = self.get_input_ids(caption)
input_ids_list.append(input_ids)
captions.append(caption)
images = torch.stack(images, dim=0)
input_ids_list = torch.stack(input_ids_list, dim=0)
example = {
"images": images,
"input_ids": input_ids_list,
"captions": captions, # for debug_dataset
"latents": None,
"image_keys": image_keys, # for debug_dataset
"loss_weights": torch.ones(batch_size, dtype=torch.float32),
}
return example
"""
raise NotImplementedError
# endregion
# region モジュール入れ替え部