From 62d00b4520aaae6076474389c9f61db2c982c9e2 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Wed, 31 May 2023 14:13:15 +0900 Subject: [PATCH] add controlnet training --- library/config_util.py | 97 ++++++- library/model_util.py | 76 ++++++ library/train_util.py | 318 ++++++++++++++++++++++ train_controlnet.py | 594 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1075 insertions(+), 10 deletions(-) create mode 100644 train_controlnet.py diff --git a/library/config_util.py b/library/config_util.py index 98b41751..ae17655c 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -33,8 +33,10 @@ from . import train_util from .train_util import ( DreamBoothSubset, FineTuningSubset, + ControlNetSubset, DreamBoothDataset, FineTuningDataset, + ControlNetDataset, DatasetGroup, ) @@ -70,6 +72,11 @@ class DreamBoothSubsetParams(BaseSubsetParams): class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None +@dataclass +class ControlNetSubsetParams(BaseSubsetParams): + conditioning_data_dir: str = None + caption_extension: str = ".caption" + @dataclass class BaseDatasetParams: tokenizer: CLIPTokenizer = None @@ -96,6 +103,15 @@ class FineTuningDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False +@dataclass +class ControlNetDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class SubsetBlueprint: params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] @@ -103,6 +119,7 @@ class SubsetBlueprint: @dataclass class DatasetBlueprint: is_dreambooth: bool + is_controlnet: bool params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] subsets: Sequence[SubsetBlueprint] @@ -163,6 +180,13 @@ class ConfigSanitizer: Required("metadata_file"): str, "image_dir": str, } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } # datasets schema DATASET_ASCENDABLE_SCHEMA = { @@ -192,8 +216,8 @@ class ConfigSanitizer: "dataset_repeats": "num_repeats", } - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None: - assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. 
/ DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, @@ -208,6 +232,13 @@ class ConfigSanitizer: self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + self.db_dataset_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, @@ -223,13 +254,23 @@ class ConfigSanitizer: {"subsets": [self.ft_subset_schema]}, ) - if support_dreambooth and support_finetuning: + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) + + if support_dreambooth and support_finetuning and support_controlnet: def validate_flex_dataset(dataset_config: dict): subsets_config = dataset_config.get("subsets", []) + if all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) # check dataset meets FT style # NOTE: all FT subsets should have "metadata_file" - if all(["metadata_file" in subset for subset in subsets_config]): + elif all(["metadata_file" in subset for subset in subsets_config]): return Schema(self.ft_dataset_schema)(dataset_config) # check dataset meets DB style # NOTE: all DB subsets should have no "metadata_file" @@ -241,13 +282,16 @@ class ConfigSanitizer: self.dataset_schema = validate_flex_dataset elif support_dreambooth: self.dataset_schema = self.db_dataset_schema - else: + elif support_finetuning: self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema self.general_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) @@ -318,7 +362,11 @@ class BlueprintGenerator: # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets subsets = dataset_config.get("subsets", []) is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - if is_dreambooth: + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: subset_params_klass = DreamBoothSubsetParams dataset_params_klass = DreamBoothDatasetParams else: @@ -333,7 +381,7 @@ class BlueprintGenerator: params = self.generate_params_by_fallbacks(dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints)) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) @@ -361,10 +409,13 @@ class BlueprintGenerator: def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = [] + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for 
dataset_blueprint in dataset_group_blueprint.datasets:
-    if dataset_blueprint.is_dreambooth:
+    if dataset_blueprint.is_controlnet:
+      subset_klass = ControlNetSubset
+      dataset_klass = ControlNetDataset
+    elif dataset_blueprint.is_dreambooth:
       subset_klass = DreamBoothSubset
       dataset_klass = DreamBoothDataset
     else:
@@ -379,6 +430,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
   info = ""
   for i, dataset in enumerate(datasets):
     is_dreambooth = isinstance(dataset, DreamBoothDataset)
+    is_controlnet = isinstance(dataset, ControlNetDataset)
     info += dedent(f"""\
       [Dataset {i}]
         batch_size: {dataset.batch_size}
@@ -421,7 +473,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
           class_tokens: {subset.class_tokens}
           caption_extension: {subset.caption_extension}
         \n"""), "  ")
-      else:
+      elif not is_controlnet:
         info += indent(dedent(f"""\
           metadata_file: {subset.metadata_file}
         \n"""), "  ")
@@ -479,6 +531,31 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
 
   return subsets_config
 
 
+def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
+  def generate(base_dir: Optional[str]):
+    if base_dir is None:
+      return []
+
+    base_dir: Path = Path(base_dir)
+    if not base_dir.is_dir():
+      return []
+
+    subsets_config = []
+    for subdir in base_dir.iterdir():
+      if not subdir.is_dir():
+        continue
+
+      subset_config = {"image_dir": str(subdir), "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
+      subsets_config.append(subset_config)
+
+    return subsets_config
+
+  subsets_config = []
+  subsets_config += generate(train_data_dir)
+
+  return subsets_config
+
+
 def load_user_config(file: str) -> dict:
   file: Path = Path(file)
   if not file.is_file():
diff --git a/library/model_util.py b/library/model_util.py
index 26f72235..bb168653 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -732,6 +732,82 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
   return new_state_dict
 
 
+def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
+  unet_conversion_map = [
+    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+    ("input_blocks.0.0.weight", "conv_in.weight"),
+    ("input_blocks.0.0.bias", "conv_in.bias"),
+    ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
+    ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
+  ]
+
+  unet_conversion_map_resnet = [
+    ("in_layers.0", "norm1"),
+    ("in_layers.2", "conv1"),
+    ("out_layers.0", "norm2"),
+    ("out_layers.3", "conv2"),
+    ("emb_layers.1", "time_emb_proj"),
+    ("skip_connection", "conv_shortcut"),
+  ]
+
+  unet_conversion_map_layer = []
+  for i in range(4):
+    for j in range(2):
+      hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+      sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+      unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+      if i < 3:
+        # no attention layers in the last down block
+        hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+        sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+        unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+    if i < 3:
+      # no downsampler in the last down block
+      hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+      sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + controlnet_cond_embedding_names = ( + ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"] + ) + for i, hf_prefix in enumerate(controlnet_cond_embedding_names): + hf_prefix = f"controlnet_cond_embedding.{hf_prefix}." + sd_prefix = f"input_hint_block.{i*2}." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + for i in range(12): + hf_prefix = f"controlnet_down_blocks.{i}." + sd_prefix = f"zero_convs.{i}.0." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[diffusers_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + # ================# # VAE Conversion # # ================# diff --git a/library/train_util.py b/library/train_util.py index 008ccd64..1921c2a4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -403,6 +403,54 @@ class FineTuningSubset(BaseSubset): return self.metadata_file == other.metadata_file +class ControlNetSubset(BaseSubset): + def __init__( + self, + image_dir: str, + conditioning_data_dir: str, + caption_extension: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, + ) + + self.conditioning_data_dir = conditioning_data_dir + self.caption_extension = caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." 
+ self.caption_extension
+
+  def __eq__(self, other) -> bool:
+    if not isinstance(other, ControlNetSubset):
+      return NotImplemented
+    return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir
+
+
 class BaseDataset(torch.utils.data.Dataset):
   def __init__(
     self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
@@ -1387,6 +1435,274 @@ class FineTuningDataset(BaseDataset):
     return npz_file_norm, npz_file_flip
 
 
+class ControlNetDataset(BaseDataset):
+  def __init__(
+    self,
+    subsets: Sequence[ControlNetSubset],
+    batch_size: int,
+    tokenizer,
+    max_token_length,
+    resolution,
+    enable_bucket: bool,
+    min_bucket_reso: int,
+    max_bucket_reso: int,
+    bucket_reso_steps: int,
+    bucket_no_upscale: bool,
+    debug_dataset) -> None:
+    super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
+    self.conditioning_image_data: Dict[str, ImageInfo] = {}
+
+    assert resolution is not None, "resolution is required / resolution(解像度)指定は必須です"
+
+    self.batch_size = batch_size
+    self.size = min(self.width, self.height)  # shorter edge
+    self.latents_cache = None
+
+    self.num_reg_images = 0
+
+    self.enable_bucket = enable_bucket
+    if self.enable_bucket:
+      assert (
+        min(resolution) >= min_bucket_reso
+      ), "min_bucket_reso must be less than or equal to the resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
+      assert (
+        max(resolution) <= max_bucket_reso
+      ), "max_bucket_reso must be greater than or equal to the resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmax_bucket_resoを大きくしてください"
+      self.min_bucket_reso = min_bucket_reso
+      self.max_bucket_reso = max_bucket_reso
+      self.bucket_reso_steps = bucket_reso_steps
+      self.bucket_no_upscale = bucket_no_upscale
+    else:
+      self.min_bucket_reso = None
+      self.max_bucket_reso = None
+      self.bucket_reso_steps = None  # not used when bucketing is disabled
+      self.bucket_no_upscale = False
+
+    def read_caption(img_path, caption_extension):
+      # build candidate caption file names
+      base_name = os.path.splitext(img_path)[0]
+      base_name_face_det = base_name
+      tokens = base_name.split("_")
+      if len(tokens) >= 5:
+        base_name_face_det = "_".join(tokens[:-4])
+      cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
+
+      caption = None
+      for cap_path in cap_paths:
+        if os.path.isfile(cap_path):
+          with open(cap_path, "rt", encoding="utf-8") as f:
+            try:
+              lines = f.readlines()
+            except UnicodeDecodeError as e:
+              print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
+              raise e
+            assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
+            caption = lines[0].strip()
+          break
+      return caption
+
+    def load_controlnet_dir(subset: ControlNetSubset):
+      if not os.path.isdir(subset.image_dir):
+        print(f"not directory: {subset.image_dir}")
+        return [], [], []  # img_paths, conditioning_img_paths, captions
+      if not os.path.isdir(subset.conditioning_data_dir):
+        print(f"not directory: {subset.conditioning_data_dir}")
+        return [], [], []
+
+      img_paths = glob_images(subset.image_dir, "*")
+      conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
+      img_paths = sorted(img_paths)
+      conditioning_img_paths = sorted(conditioning_img_paths)
+      print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
+      print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
+
+      img_basenames = [os.path.basename(img) for img in img_paths]
+      conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
+      missing_imgs
= [] + extra_imgs = [] + + for img in img_basenames: + if img not in conditioning_img_basenames: + missing_imgs.append(img) + for img in conditioning_img_basenames: + if img not in img_basenames: + extra_imgs.append(img) + + assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" + assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + + + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + missing_captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension) + if cap_for_img is None: + print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}") + captions.append("") + missing_captions.append(img_path) + else: + captions.append(cap_for_img) + + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + + if missing_captions: + number_of_missing_captions = len(missing_captions) + number_of_missing_captions_to_show = 5 + remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show + + print( + f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" + ) + for i, missing_caption in enumerate(missing_captions): + if i >= number_of_missing_captions_to_show: + print(missing_caption + f"... and {remaining_missing_captions} more") + break + print(missing_caption) + return img_paths, conditioning_img_paths, captions + + print("prepare images.") + num_train_images = 0 + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + ) + continue + + if subset in self.subsets: + print( + f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" + ) + continue + + img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset) + if len(img_paths) < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + continue + + num_train_images += subset.num_repeats * len(img_paths) + + for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions): + info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path) + setattr(info, "cond_img_path", cond_img_path) + self.register_image(info, subset) + + subset.img_count = len(img_paths) + self.subsets.append(subset) + + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + self.conditioning_image_transforms = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + + def __getitem__(self, index): + bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] + bucket_batch_size = self.buckets_indices[index].bucket_batch_size + image_index = self.buckets_indices[index].batch_index * bucket_batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + conditioning_images = [] + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + 
loss_weights.append(1.0) + + # image/latentsを処理する + if image_info.latents is not None: # cache_latents=Trueの場合 + latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 + latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img = self.load_image(image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + else: + im_h, im_w = img.shape[0:2] + assert ( + im_h == self.height and im_w == self.width + ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + if aug is not None: + img = aug(image=img)["image"] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(subset, image_info.caption) + if self.XTI_layers: + caption_layer = [] + for layer in self.XTI_layers: + token_strings_from = " ".join(self.token_strings) + token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + caption_ = caption.replace(token_strings_from, token_strings_to) + caption_layer.append(caption_) + captions.append(caption_layer) + else: + captions.append(caption) + if not self.token_padding_disabled: # this option might be omitted in future + if self.XTI_layers: + token_caption = self.get_input_ids(caption_layer) + else: + token_caption = self.get_input_ids(caption) + input_ids_list.append(token_caption) + + assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}" + + cond_img = self.load_image(image_info.cond_img_path) + if self.enable_bucket: + cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size) + cond_img = self.conditioning_image_transforms(cond_img) + conditioning_images.append(cond_img) + conditioning_images = torch.stack(conditioning_images) + + example = {} + example["loss_weights"] = torch.FloatTensor(loss_weights) + + if self.token_padding_disabled: + # padding=True means pad in the batch + example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example["input_ids"] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example["images"] = images + + example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + example["captions"] = captions + + if self.debug_dataset: + example["image_keys"] = bucket[image_index : image_index + self.batch_size] + + example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float() + + return example + # behave as Dataset mock class DatasetGroup(torch.utils.data.ConcatDataset): def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): @@ -1636,6 +1952,8 @@ def get_git_revision_hash() -> str: def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, 
xformers):
   replace_attentions_for_hypernetwork()
   # unet is not used currently, but it is here for future use
+  if xformers:
+    return unet.enable_xformers_memory_efficient_attention()
   if mem_eff_attn:
     unet.set_attn_processor(FlashAttnProcessor())
   elif xformers:
diff --git a/train_controlnet.py b/train_controlnet.py
new file mode 100644
index 00000000..7bcaf03a
--- /dev/null
+++ b/train_controlnet.py
@@ -0,0 +1,594 @@
+import argparse
+import gc
+import math
+import os
+import random
+import time
+from multiprocessing import Value
+
+from tqdm import tqdm
+import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
+from accelerate.utils import set_seed
+from diffusers import DDPMScheduler, ControlNetModel
+
+import library.model_util as model_util
+import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.huggingface_util as huggingface_util
+import library.custom_train_functions as custom_train_functions
+from library.custom_train_functions import (
+    apply_snr_weight,
+    pyramid_noise_like,
+    apply_noise_offset,
+)
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    download_controlnet_from_original_ckpt,
+)
+
+
+# TODO: unify this with the other training scripts
+def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
+    logs = {
+        "loss/current": current_loss,
+        "loss/average": avr_loss,
+        "lr": lr_scheduler.get_last_lr()[0],
+    }
+
+    if args.optimizer_type.lower().startswith("dadapt"):
+        logs["lr/d*lr"] = (
+            lr_scheduler.optimizers[-1].param_groups[0]["d"]
+            * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+        )
+
+    return logs
+
+
+def train(args):
+    session_id = random.randint(0, 2**32)
+    training_started_at = time.time()
+    train_util.verify_training_args(args)
+    train_util.prepare_dataset_args(args, True)
+
+    cache_latents = args.cache_latents
+    use_user_config = args.dataset_config is not None
+
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32)
+    set_seed(args.seed)
+
+    tokenizer = train_util.load_tokenizer(args)
+
+    # prepare the dataset
+    blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
+    if use_user_config:
+        print(f"Load dataset config from {args.dataset_config}")
+        user_config = config_util.load_user_config(args.dataset_config)
+        ignored = ["train_data_dir", "conditioning_data_dir"]
+        if any(getattr(args, attr) is not None for attr in ignored):
+            print(
+                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                    ", ".join(ignored)
+                )
+            )
+    else:
+        user_config = {
+            "datasets": [
+                {
+                    "subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
+                        args.train_data_dir,
+                        args.conditioning_data_dir,
+                        args.caption_extension,
+                    )
+                }
+            ]
+        }
+
+    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+    train_dataset_group = config_util.generate_dataset_group_by_blueprint(
+        blueprint.dataset_group
+    )
+
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
+    ds_for_collater = (
+        train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    )
+    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+
+    if args.debug_dataset:
+        train_util.debug_dataset(train_dataset_group)
+        return
+    if len(train_dataset_group) == 0:
+        print(
+            "No data found.
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model( + args, weight_dtype, accelerator + ) + if args.controlnet_model_name_or_path: + if os.path.isfile(args.controlnet_model_name_or_path): + controlnet = download_controlnet_from_original_ckpt( + args.controlnet_model_name_or_path + ) + else: + controlnet = ControlNetModel.from_pretrained( + args.controlnet_model_name_or_path + ) + else: + controlnet = ControlNetModel.from_unet(unet) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + + trainable_params = controlnet.parameters() + + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer( + args, trainable_params + ) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min( + args.max_data_loader_n_workers, os.cpu_count() - 1 + ) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) + / accelerator.num_processes + / args.gradient_accumulation_steps + ) + if is_main_process: + print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix( + args, optimizer, accelerator.num_processes + ) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + controlnet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.to(accelerator.device) + text_encoder.to(accelerator.device) + + # transform DDP after prepare + controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet + + controlnet.train() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = ( + math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + ) + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + + if is_main_process: + print("running training / 学習開始") + print( + f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}" + ) + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print( + f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}" + ) + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + if accelerator.is_main_process: + accelerator.init_trackers( + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name + ) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"\nsaving checkpoint: {ckpt_file}") + + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in 
list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(save_dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(ckpt_file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            save_file(state_dict, ckpt_file)
+        else:
+            torch.save(state_dict, ckpt_file)
+
+        if args.huggingface_repo_id is not None:
+            huggingface_util.upload(
+                args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload
+            )
+
+    def remove_model(old_ckpt_name):
+        old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
+        if os.path.exists(old_ckpt_file):
+            print(f"removing old checkpoint: {old_ckpt_file}")
+            os.remove(old_ckpt_file)
+
+    # training loop
+    for epoch in range(num_train_epochs):
+        if is_main_process:
+            print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        current_epoch.value = epoch + 1
+
+        for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
+            with accelerator.accumulate(controlnet):
+                with torch.no_grad():
+                    if "latents" in batch and batch["latents"] is not None:
+                        latents = batch["latents"].to(accelerator.device)
+                    else:
+                        # encode images to latents
+                        latents = vae.encode(
+                            batch["images"].to(dtype=weight_dtype)
+                        ).latent_dist.sample()
+                        latents = latents * 0.18215
+                    b_size = latents.shape[0]
+
+                    input_ids = batch["input_ids"].to(accelerator.device)
+                    encoder_hidden_states = train_util.get_hidden_states(
+                        args, input_ids, tokenizer, text_encoder, weight_dtype
+                    )
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn_like(latents, device=latents.device)
+                if args.noise_offset:
+                    noise = apply_noise_offset(
+                        latents, noise, args.noise_offset, args.adaptive_noise_scale
+                    )
+                elif args.multires_noise_iterations:
+                    noise = pyramid_noise_like(
+                        noise,
+                        latents.device,
+                        args.multires_noise_iterations,
+                        args.multires_noise_discount,
+                    )
+
+                # Sample a random timestep for each image
+                timesteps = torch.randint(
+                    0,
+                    noise_scheduler.config.num_train_timesteps,
+                    (b_size,),
+                    device=latents.device,
+                )
+                timesteps = timesteps.long()
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
+
+                with accelerator.autocast():
+                    down_block_res_samples, mid_block_res_sample = controlnet(
+                        noisy_latents,
+                        timesteps,
+                        encoder_hidden_states=encoder_hidden_states,
+                        controlnet_cond=controlnet_image,
+                        return_dict=False,
+                    )
+
+                    # Predict the noise residual
+                    noise_pred = unet(
+                        noisy_latents,
+                        timesteps,
+                        encoder_hidden_states,
+                        down_block_additional_residuals=[
+                            sample.to(dtype=weight_dtype)
+                            for sample in down_block_res_samples
+                        ],
+                        mid_block_additional_residual=mid_block_res_sample.to(
+                            dtype=weight_dtype
+                        ),
+                    ).sample
+
+                if args.v_parameterization:
+                    # v-parameterization training
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    target = noise
+
+                loss = torch.nn.functional.mse_loss(
+                    noise_pred.float(), target.float(), reduction="none"
+                )
+                loss = loss.mean([1, 2, 3])
+
+                loss_weights = batch["loss_weights"]  # per-sample weights
+                loss = loss * loss_weights
+
+                if args.min_snr_gamma:
+                    loss = apply_snr_weight(
+                        loss, timesteps, noise_scheduler, args.min_snr_gamma
+                    )
+
+                loss = loss.mean()  # already averaged, no need to divide by batch size
+
+                accelerator.backward(loss)
+                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                    params_to_clip = controlnet.parameters()
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + ) + + # 指定ステップごとにモデルを保存 + if ( + args.save_every_n_steps is not None + and global_step % args.save_every_n_steps == 0 + ): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name( + args, "." + args.save_model_as, global_step + ) + save_model( + ckpt_name, unwrap_model(controlnet), global_step, epoch + ) + + if args.save_state: + train_util.save_and_remove_state_stepwise( + args, accelerator, global_step + ) + + remove_step_no = train_util.get_remove_step_no( + args, global_step + ) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name( + args, "." + args.save_model_as, remove_step_no + ) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and ( + epoch + 1 + ) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name( + args, "." + args.save_model_as, epoch + 1 + ) + save_model(ckpt_name, unwrap_model(controlnet), global_step, epoch + 1) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name( + args, "." + args.save_model_as, remove_epoch_no + ) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end( + args, accelerator, epoch + 1 + ) + + train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + ) + + # end of epoch + if is_main_process: + controlnet = unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) + save_model( + ckpt_name, controlnet, global_step, num_train_epochs, force_sync_upload=True + ) + + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args)
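
Usage notes for reviewers follow; none of this is part of the diff, and all paths and names are illustrative. The ControlNet branch of the config schema is selected whenever every subset in a dataset defines conditioning_data_dir. A minimal --dataset_config sketch, assuming the TOML form accepted by config_util.load_user_config:

    [general]
    resolution = 512
    enable_bucket = true

    [[datasets]]
    batch_size = 4

      [[datasets.subsets]]
      image_dir = "/data/controlnet/images"
      conditioning_data_dir = "/data/controlnet/cond"
      caption_extension = ".txt"
      num_repeats = 1

ControlNetDataset pairs the two directories by file basename and asserts that neither side has unmatched files, so every training image needs a same-named conditioning image.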
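
Without --dataset_config, train() falls back to generate_controlnet_subsets_config_by_subdirs, which builds one subset per subdirectory of train_data_dir and points them all at a single conditioning_data_dir. A minimal sketch of that call (the directory layout is assumed, not prescribed by the patch):

    import library.config_util as config_util

    subsets = config_util.generate_controlnet_subsets_config_by_subdirs(
        train_data_dir="/data/train",        # parent folder; each subdirectory becomes one subset
        conditioning_data_dir="/data/cond",  # shared conditioning images, matched by basename
        caption_extension=".txt",
    )
    # e.g. [{"image_dir": "/data/train/subjectA", "conditioning_data_dir": "/data/cond",
    #        "caption_extension": ".txt", "num_repeats": 1}, ...]

This is also why the "No data found" message says train_data_dir must be the parent of the image folders rather than an image folder itself.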
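
Because load_controlnet_dir hard-asserts on any basename mismatch between the two directories, it can be worth checking the pairing before a long run. A standalone sketch mirroring those asserts with the same glob_images helper (paths illustrative):

    import os

    from library.train_util import glob_images

    image_dir, cond_dir = "/data/train/subjectA", "/data/cond"
    imgs = {os.path.basename(p) for p in glob_images(image_dir, "*")}
    conds = {os.path.basename(p) for p in glob_images(cond_dir, "*")}
    print("missing conditioning images:", sorted(imgs - conds))
    print("unmatched conditioning images:", sorted(conds - imgs))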
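
convert_controlnet_state_dict_to_sd maps diffusers ControlNetModel keys (time_embedding.*, controlnet_cond_embedding.*, controlnet_down_blocks.*, controlnet_mid_block.*) back to the original input_hint_block / zero_convs / middle_block_out layout, and save_model above uses it for every checkpoint it writes. The same conversion can be exercised standalone; a sketch, assuming any SD 1.x ControlNet (the repo id is only an example):

    import torch
    from diffusers import ControlNetModel
    from safetensors.torch import save_file

    import library.model_util as model_util

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
    sd = model_util.convert_controlnet_state_dict_to_sd(controlnet.state_dict())
    # optional precision cast, as save_model does when save_dtype is set
    sd = {k: v.detach().to("cpu").to(torch.float16) for k, v in sd.items()}
    save_file(sd, "controlnet.safetensors")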