add controlnet training

This commit is contained in:
ddPn08
2023-05-31 14:13:15 +09:00
parent 4f8ce00477
commit 62d00b4520
4 changed files with 1075 additions and 10 deletions

View File

@@ -33,8 +33,10 @@ from . import train_util
from .train_util import (
DreamBoothSubset,
FineTuningSubset,
ControlNetSubset,
DreamBoothDataset,
FineTuningDataset,
ControlNetDataset,
DatasetGroup,
)
@@ -70,6 +72,11 @@ class DreamBoothSubsetParams(BaseSubsetParams):
class FineTuningSubsetParams(BaseSubsetParams):
metadata_file: Optional[str] = None
@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
conditioning_data_dir: str = None
caption_extension: str = ".caption"
@dataclass
class BaseDatasetParams:
tokenizer: CLIPTokenizer = None
@@ -96,6 +103,15 @@ class FineTuningDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
@dataclass
class ControlNetDatasetParams(BaseDatasetParams):
batch_size: int = 1
enable_bucket: bool = False
min_bucket_reso: int = 256
max_bucket_reso: int = 1024
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
@dataclass
class SubsetBlueprint:
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
@@ -103,6 +119,7 @@ class SubsetBlueprint:
@dataclass
class DatasetBlueprint:
is_dreambooth: bool
is_controlnet: bool
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
subsets: Sequence[SubsetBlueprint]
@@ -163,6 +180,13 @@ class ConfigSanitizer:
Required("metadata_file"): str,
"image_dir": str,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
}
CN_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
Required("conditioning_data_dir"): str,
}
# datasets schema
DATASET_ASCENDABLE_SCHEMA = {
@@ -192,8 +216,8 @@ class ConfigSanitizer:
"dataset_repeats": "num_repeats",
}
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -208,6 +232,13 @@ class ConfigSanitizer:
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.cn_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
self.CN_SUBSET_DISTINCT_SCHEMA,
self.CN_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.db_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -223,13 +254,23 @@ class ConfigSanitizer:
{"subsets": [self.ft_subset_schema]},
)
if support_dreambooth and support_finetuning:
self.cn_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.CN_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
{"subsets": [self.cn_subset_schema]},
)
if support_dreambooth and support_finetuning and support_controlnet:
def validate_flex_dataset(dataset_config: dict):
subsets_config = dataset_config.get("subsets", [])
if all(["conditioning_data_dir" in subset for subset in subsets_config]):
return Schema(self.cn_dataset_schema)(dataset_config)
# check dataset meets FT style
# NOTE: all FT subsets should have "metadata_file"
if all(["metadata_file" in subset for subset in subsets_config]):
elif all(["metadata_file" in subset for subset in subsets_config]):
return Schema(self.ft_dataset_schema)(dataset_config)
# check dataset meets DB style
# NOTE: all DB subsets should have no "metadata_file"
@@ -241,13 +282,16 @@ class ConfigSanitizer:
self.dataset_schema = validate_flex_dataset
elif support_dreambooth:
self.dataset_schema = self.db_dataset_schema
else:
elif support_finetuning:
self.dataset_schema = self.ft_dataset_schema
elif support_controlnet:
self.dataset_schema = self.cn_dataset_schema
self.general_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
@@ -318,7 +362,11 @@ class BlueprintGenerator:
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
subsets = dataset_config.get("subsets", [])
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
if is_dreambooth:
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
if is_controlnet:
subset_params_klass = ControlNetSubsetParams
dataset_params_klass = ControlNetDatasetParams
elif is_dreambooth:
subset_params_klass = DreamBoothSubsetParams
dataset_params_klass = DreamBoothDatasetParams
else:
@@ -333,7 +381,7 @@ class BlueprintGenerator:
params = self.generate_params_by_fallbacks(dataset_params_klass,
[dataset_config, general_config, argparse_config, runtime_params])
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
@@ -361,10 +409,13 @@ class BlueprintGenerator:
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.is_dreambooth:
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
@@ -379,6 +430,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
@@ -421,7 +473,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
else:
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")
@@ -479,6 +531,31 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
return subsets_config
def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
def generate(base_dir: Optional[str]):
if base_dir is None:
return []
base_dir: Path = Path(base_dir)
if not base_dir.is_dir():
return []
subsets_config = []
for subdir in base_dir.iterdir():
if not subdir.is_dir():
continue
subset_config = {"image_dir": str(subdir), "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
subsets_config.append(subset_config)
return subsets_config
subsets_config = []
subsets_config += generate(train_data_dir, False)
return subsets_config
def load_user_config(file: str) -> dict:
file: Path = Path(file)
if not file.is_file():

View File

@@ -732,6 +732,82 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
return new_state_dict
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
unet_conversion_map = [
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
]
unet_conversion_map_resnet = [
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
for i in range(4):
for j in range(2):
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
controlnet_cond_embedding_names = (
["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
)
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
sd_prefix = f"input_hint_block.{i*2}."
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
for i in range(12):
hf_prefix = f"controlnet_down_blocks.{i}."
sd_prefix = f"zero_convs.{i}.0."
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
mapping = {k: k for k in controlnet_state_dict.keys()}
for sd_name, diffusers_name in unet_conversion_map:
mapping[diffusers_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, diffusers_part in unet_conversion_map_resnet:
v = v.replace(diffusers_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, diffusers_part in unet_conversion_map_layer:
v = v.replace(diffusers_part, sd_part)
mapping[k] = v
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#

View File

@@ -403,6 +403,54 @@ class FineTuningSubset(BaseSubset):
return self.metadata_file == other.metadata_file
class ControlNetSubset(BaseSubset):
def __init__(
self,
image_dir: str,
conditioning_data_dir: str,
caption_extension: str,
num_repeats,
shuffle_caption,
keep_tokens,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
super().__init__(
image_dir,
num_repeats,
shuffle_caption,
keep_tokens,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.conditioning_data_dir = conditioning_data_dir
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
def __eq__(self, other) -> bool:
if not isinstance(other, ControlNetSubset):
return NotImplemented
return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir
class BaseDataset(torch.utils.data.Dataset):
def __init__(
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
@@ -1387,6 +1435,274 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
class ControlNetDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[ControlNetSubset],
batch_size: int,
tokenizer,
max_token_length,
resolution,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
self.conditioning_image_data: Dict[str, ImageInfo] = {}
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.latents_cache = None
self.num_reg_images = 0
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
break
return caption
def load_controlnet_dir(subset: ControlNetSubset):
if not os.path.isdir(subset.image_dir):
print(f"not directory: {subset.image_dir}")
return [], []
if not os.path.isdir(subset.conditioning_data_dir):
print(f"not directory: {subset.conditioning_data_dir}")
return [], []
img_paths = glob_images(subset.image_dir, "*")
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
img_paths = sorted(img_paths)
conditioning_img_paths = sorted(conditioning_img_paths)
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
img_basenames = [os.path.basename(img) for img in img_paths]
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
missing_imgs = []
extra_imgs = []
for img in img_basenames:
if img not in conditioning_img_basenames:
missing_imgs.append(img)
for img in conditioning_img_basenames:
if img not in img_basenames:
extra_imgs.append(img)
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
captions.append("")
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
print(
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
)
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(missing_caption + f"... and {remaining_missing_captions} more")
break
print(missing_caption)
return img_paths, conditioning_img_paths, captions
print("prepare images.")
num_train_images = 0
for subset in subsets:
if subset.num_repeats < 1:
print(
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
)
continue
if subset in self.subsets:
print(
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
)
continue
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
if len(img_paths) < 1:
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
continue
num_train_images += subset.num_repeats * len(img_paths)
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
setattr(info, "cond_img_path", cond_img_path)
self.register_image(info, subset)
subset.img_count = len(img_paths)
self.subsets.append(subset)
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
self.conditioning_image_transforms = transforms.Compose(
[
transforms.ToTensor(),
]
)
def __getitem__(self, index):
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
loss_weights = []
captions = []
input_ids_list = []
latents_list = []
images = []
conditioning_images = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(1.0)
# image/latentsを処理する
if image_info.latents is not None: # cache_latents=Trueの場合
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
latents = torch.FloatTensor(latents)
image = None
else:
# 画像を読み込み、必要ならcropする
img = self.load_image(image_info.absolute_path)
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
else:
im_h, im_w = img.shape[0:2]
assert (
im_h == self.height and im_w == self.width
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
if aug is not None:
img = aug(image=img)["image"]
latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
images.append(image)
latents_list.append(latents)
caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer)
else:
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
cond_img = self.load_image(image_info.cond_img_path)
if self.enable_bucket:
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
cond_img = self.conditioning_image_transforms(cond_img)
conditioning_images.append(cond_img)
conditioning_images = torch.stack(conditioning_images)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
if self.token_padding_disabled:
# padding=True means pad in the batch
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
else:
# batch processing seems to be good
example["input_ids"] = torch.stack(input_ids_list)
if images[0] is not None:
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
else:
images = None
example["images"] = images
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
return example
# behave as Dataset mock
class DatasetGroup(torch.utils.data.ConcatDataset):
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
@@ -1636,6 +1952,8 @@ def get_git_revision_hash() -> str:
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
replace_attentions_for_hypernetwork()
# unet is not used currently, but it is here for future use
unet.enable_xformers_memory_efficient_attention()
return
if mem_eff_attn:
unet.set_attn_processor(FlashAttnProcessor())
elif xformers: