support arbitrary dataset for train_network.py

This commit is contained in:
Kohya S
2023-06-14 12:49:12 +09:00
parent 8088c04a71
commit 9aee793078
2 changed files with 115 additions and 33 deletions

View File

@@ -1518,6 +1518,67 @@ def glob_images_pathlib(dir_path, recursive):
    return image_paths
class MinimalDataset(BaseDataset):
    """Minimal abstract dataset for user-supplied (arbitrary) dataset classes.

    Subclasses must implement ``__len__`` and ``__getitem__`` and should update
    the bookkeeping attributes initialized below (num_train_images, batch_size,
    etc.). The instance acts as its own dataset group and subset
    (``self.datasets = [self]``, ``self.subsets = [self]``) so the training
    loop's group/subset accounting still works.
    """

    def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        self.num_train_images = 0  # update in subclass
        self.num_reg_images = 0  # update in subclass
        self.datasets = [self]
        self.batch_size = 1  # update in subclass

        self.subsets = [self]
        self.num_repeats = 1  # update in subclass if needed
        self.img_count = 1  # update in subclass if needed
        self.bucket_info = {}
        self.is_reg = False
        self.image_dir = "dummy"  # for metadata

    def is_latent_cacheable(self) -> bool:
        # Latents are produced on the fly by the subclass; nothing to cache here.
        return False

    def __len__(self):
        raise NotImplementedError

    # override to avoid shuffling buckets
    def set_current_epoch(self, epoch):
        self.current_epoch = epoch

    def __getitem__(self, idx):
        r"""
        The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.

        Returns: example like this:

            for i in range(batch_size):
                image_key = ...  # whatever hashable
                image_keys.append(image_key)

                image = ...  # PIL Image
                img_tensor = self.image_transforms(image)
                images.append(img_tensor)

                caption = ...  # str
                input_ids = self.get_input_ids(caption)
                input_ids_list.append(input_ids)

                captions.append(caption)

            images = torch.stack(images, dim=0)
            input_ids_list = torch.stack(input_ids_list, dim=0)
            example = {
                "images": images,
                "input_ids": input_ids_list,
                "captions": captions,  # for debug_dataset
                "latents": None,
                "image_keys": image_keys,  # for debug_dataset
                "loss_weights": torch.ones(batch_size, dtype=torch.float32),
            }
            return example
        """
        raise NotImplementedError
# endregion
# region モジュール入れ替え部

View File

@@ -92,42 +92,56 @@ def train(args):
tokenizer = train_util.load_tokenizer(args)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) if args.dataset_class is None:
if use_user_config: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
print(f"Loading dataset config from {args.dataset_config}") if use_user_config:
user_config = config_util.load_user_config(args.dataset_config) print(f"Loading dataset config from {args.dataset_config}")
ignored = ["train_data_dir", "reg_data_dir", "in_json"] user_config = config_util.load_user_config(args.dataset_config)
if any(getattr(args, attr) is not None for attr in ignored): ignored = ["train_data_dir", "reg_data_dir", "in_json"]
print( if any(getattr(args, attr) is not None for attr in ignored):
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( print(
", ".join(ignored) "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
) )
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else: else:
print("Training with captions.") if use_dreambooth_method:
user_config = { print("Using DreamBooth method.")
"datasets": [ user_config = {
{ "datasets": [
"subsets": [ {
{ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
"image_dir": args.train_data_dir, args.train_data_dir, args.reg_data_dir
"metadata_file": args.in_json, )
} }
] ]
} }
] else:
} print("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
module = ".".join(args.dataset_class.split(".")[:-1])
dataset_class = args.dataset_class.split(".")[-1]
module = importlib.import_module(module)
dataset_class = getattr(module, dataset_class)
train_dataset_group: train_util.MinimalDataset = dataset_class(
tokenizer, args.max_token_length, args.resolution, args.debug_dataset
)
current_epoch = Value("i", 0) current_epoch = Value("i", 0)
current_step = Value("i", 0) current_step = Value("i", 0)
@@ -185,6 +199,7 @@ def train(args):
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
print(f"all weights merged: {', '.join(args.base_weights)}") print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
@@ -852,6 +867,12 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*", nargs="*",
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
) )
parser.add_argument(
"--dataset_class",
type=str,
default=None,
help="dataset class for arbitrary dataset / 任意のデータセットのクラス名",
)
return parser