diff --git a/library/train_util.py b/library/train_util.py index 844faca7..e1046d58 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1518,6 +1518,67 @@ def glob_images_pathlib(dir_path, recursive): return image_paths +class MinimalDataset(BaseDataset): + def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + + self.num_train_images = 0 # update in subclass + self.num_reg_images = 0 # update in subclass + self.datasets = [self] + self.batch_size = 1 # update in subclass + + self.subsets = [self] + self.num_repeats = 1 # update in subclass if needed + self.img_count = 1 # update in subclass if needed + self.bucket_info = {} + self.is_reg = False + self.image_dir = "dummy" # for metadata + + def is_latent_cacheable(self) -> bool: + return False + + def __len__(self): + raise NotImplementedError + + # override to avoid shuffling buckets + def set_current_epoch(self, epoch): + self.current_epoch = epoch + + def __getitem__(self, idx): + r""" + The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects. + + Returns: example like this: + + for i in range(batch_size): + image_key = ... # whatever hashable + image_keys.append(image_key) + + image = ... # PIL Image + img_tensor = self.image_transforms(img) + images.append(img_tensor) + + caption = ... # str + input_ids = self.get_input_ids(caption) + input_ids_list.append(input_ids) + + captions.append(caption) + + images = torch.stack(images, dim=0) + input_ids_list = torch.stack(input_ids_list, dim=0) + example = { + "images": images, + "input_ids": input_ids_list, + "captions": captions, # for debug_dataset + "latents": None, + "image_keys": image_keys, # for debug_dataset + "loss_weights": torch.ones(batch_size, dtype=torch.float32), + } + return example + """ + raise NotImplementedError + + # endregion # region モジュール入れ替え部 diff --git a/train_network.py b/train_network.py index b62aef7e..6f845b5a 100644 --- a/train_network.py +++ b/train_network.py @@ -92,42 +92,56 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) - if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Loading dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - if use_dreambooth_method: - print("Using DreamBooth method.") - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } else: - print("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + # use arbitrary dataset class + module = ".".join(args.dataset_class.split(".")[:-1]) + dataset_class = args.dataset_class.split(".")[-1] + module = importlib.import_module(module) + dataset_class = getattr(module, dataset_class) + train_dataset_group: train_util.MinimalDataset = dataset_class( + tokenizer, args.max_token_length, args.resolution, args.debug_dataset + ) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -185,6 +199,7 @@ def train(args): module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") print(f"all weights merged: {', '.join(args.base_weights)}") + # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -852,6 +867,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) + parser.add_argument( + "--dataset_class", + type=str, + default=None, + help="dataset class for arbitrary dataset / 任意のデータセットのクラス名", + ) return parser