mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support arbitrary dataset for train_network.py
This commit is contained in:
@@ -1518,6 +1518,67 @@ def glob_images_pathlib(dir_path, recursive):
|
|||||||
return image_paths
|
return image_paths
|
||||||
|
|
||||||
|
|
||||||
|
class MinimalDataset(BaseDataset):
    """Minimal abstract base for arbitrary user-supplied dataset classes.

    Users point ``--dataset_class`` at a subclass of this; the trainer then
    constructs it with ``(tokenizer, max_token_length, resolution,
    debug_dataset)``. Subclasses must implement ``__len__`` and
    ``__getitem__``; the attributes initialized here are placeholders that the
    training loop reads, and should be updated by the subclass as noted.
    """

    def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        self.num_train_images = 0  # update in subclass
        self.num_reg_images = 0  # update in subclass

        # This dataset acts as its own (single-entry) dataset group and subset.
        self.datasets = [self]
        self.batch_size = 1  # update in subclass

        self.subsets = [self]
        self.num_repeats = 1  # update in subclass if needed
        self.img_count = 1  # update in subclass if needed
        self.bucket_info = {}
        self.is_reg = False
        self.image_dir = "dummy"  # for metadata

    def is_latent_cacheable(self) -> bool:
        # Arbitrary datasets produce images on the fly, so latents cannot be
        # pre-cached by the trainer.
        return False

    def __len__(self):
        raise NotImplementedError

    # override to avoid shuffling buckets
    def set_current_epoch(self, epoch):
        self.current_epoch = epoch

    def __getitem__(self, idx):
        r"""
        The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.

        Returns: example like this:

            for i in range(batch_size):
                image_key = ...  # whatever hashable
                image_keys.append(image_key)

                image = ...  # PIL Image
                img_tensor = self.image_transforms(img)
                images.append(img_tensor)

                caption = ...  # str
                input_ids = self.get_input_ids(caption)
                input_ids_list.append(input_ids)

                captions.append(caption)

            images = torch.stack(images, dim=0)
            input_ids_list = torch.stack(input_ids_list, dim=0)
            example = {
                "images": images,
                "input_ids": input_ids_list,
                "captions": captions,  # for debug_dataset
                "latents": None,
                "image_keys": image_keys,  # for debug_dataset
                "loss_weights": torch.ones(batch_size, dtype=torch.float32),
            }
            return example
        """
        raise NotImplementedError
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region モジュール入れ替え部
|
# region モジュール入れ替え部
|
||||||
|
|||||||
@@ -92,42 +92,56 @@ def train(args):
|
|||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
if args.dataset_class is None:
|
||||||
if use_user_config:
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
||||||
print(f"Loading dataset config from {args.dataset_config}")
|
if use_user_config:
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
print(f"Loading dataset config from {args.dataset_config}")
|
||||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||||
print(
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
print(
|
||||||
", ".join(ignored)
|
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
|
", ".join(ignored)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
if use_dreambooth_method:
|
|
||||||
print("Using DreamBooth method.")
|
|
||||||
user_config = {
|
|
||||||
"datasets": [
|
|
||||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
print("Training with captions.")
|
if use_dreambooth_method:
|
||||||
user_config = {
|
print("Using DreamBooth method.")
|
||||||
"datasets": [
|
user_config = {
|
||||||
{
|
"datasets": [
|
||||||
"subsets": [
|
{
|
||||||
{
|
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||||
"image_dir": args.train_data_dir,
|
args.train_data_dir, args.reg_data_dir
|
||||||
"metadata_file": args.in_json,
|
)
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
else:
|
||||||
}
|
print("Training with captions.")
|
||||||
|
user_config = {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"subsets": [
|
||||||
|
{
|
||||||
|
"image_dir": args.train_data_dir,
|
||||||
|
"metadata_file": args.in_json,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
else:
|
||||||
|
# use arbitrary dataset class
|
||||||
|
module = ".".join(args.dataset_class.split(".")[:-1])
|
||||||
|
dataset_class = args.dataset_class.split(".")[-1]
|
||||||
|
module = importlib.import_module(module)
|
||||||
|
dataset_class = getattr(module, dataset_class)
|
||||||
|
train_dataset_group: train_util.MinimalDataset = dataset_class(
|
||||||
|
tokenizer, args.max_token_length, args.resolution, args.debug_dataset
|
||||||
|
)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
@@ -185,6 +199,7 @@ def train(args):
|
|||||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||||
|
|
||||||
print(f"all weights merged: {', '.join(args.base_weights)}")
|
print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
@@ -852,6 +867,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
nargs="*",
|
nargs="*",
|
||||||
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_class",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="dataset class for arbitrary dataset / 任意のデータセットのクラス名",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user