mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add controlnet training
This commit is contained in:
@@ -33,8 +33,10 @@ from . import train_util
|
||||
from .train_util import (
|
||||
DreamBoothSubset,
|
||||
FineTuningSubset,
|
||||
ControlNetSubset,
|
||||
DreamBoothDataset,
|
||||
FineTuningDataset,
|
||||
ControlNetDataset,
|
||||
DatasetGroup,
|
||||
)
|
||||
|
||||
@@ -70,6 +72,11 @@ class DreamBoothSubsetParams(BaseSubsetParams):
|
||||
class FineTuningSubsetParams(BaseSubsetParams):
|
||||
metadata_file: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class ControlNetSubsetParams(BaseSubsetParams):
|
||||
conditioning_data_dir: str = None
|
||||
caption_extension: str = ".caption"
|
||||
|
||||
@dataclass
|
||||
class BaseDatasetParams:
|
||||
tokenizer: CLIPTokenizer = None
|
||||
@@ -96,6 +103,15 @@ class FineTuningDatasetParams(BaseDatasetParams):
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
|
||||
@dataclass
|
||||
class ControlNetDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
enable_bucket: bool = False
|
||||
min_bucket_reso: int = 256
|
||||
max_bucket_reso: int = 1024
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
|
||||
@dataclass
|
||||
class SubsetBlueprint:
|
||||
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
|
||||
@@ -103,6 +119,7 @@ class SubsetBlueprint:
|
||||
@dataclass
|
||||
class DatasetBlueprint:
|
||||
is_dreambooth: bool
|
||||
is_controlnet: bool
|
||||
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
|
||||
subsets: Sequence[SubsetBlueprint]
|
||||
|
||||
@@ -163,6 +180,13 @@ class ConfigSanitizer:
|
||||
Required("metadata_file"): str,
|
||||
"image_dir": str,
|
||||
}
|
||||
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
"caption_extension": str,
|
||||
}
|
||||
CN_SUBSET_DISTINCT_SCHEMA = {
|
||||
Required("image_dir"): str,
|
||||
Required("conditioning_data_dir"): str,
|
||||
}
|
||||
|
||||
# datasets schema
|
||||
DATASET_ASCENDABLE_SCHEMA = {
|
||||
@@ -192,8 +216,8 @@ class ConfigSanitizer:
|
||||
"dataset_repeats": "num_repeats",
|
||||
}
|
||||
|
||||
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
|
||||
assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
|
||||
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
||||
assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
|
||||
|
||||
self.db_subset_schema = self.__merge_dict(
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
@@ -208,6 +232,13 @@ class ConfigSanitizer:
|
||||
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
||||
)
|
||||
|
||||
self.cn_subset_schema = self.__merge_dict(
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.CN_SUBSET_DISTINCT_SCHEMA,
|
||||
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
||||
)
|
||||
|
||||
self.db_dataset_schema = self.__merge_dict(
|
||||
self.DATASET_ASCENDABLE_SCHEMA,
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
@@ -223,13 +254,23 @@ class ConfigSanitizer:
|
||||
{"subsets": [self.ft_subset_schema]},
|
||||
)
|
||||
|
||||
if support_dreambooth and support_finetuning:
|
||||
self.cn_dataset_schema = self.__merge_dict(
|
||||
self.DATASET_ASCENDABLE_SCHEMA,
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
||||
{"subsets": [self.cn_subset_schema]},
|
||||
)
|
||||
|
||||
if support_dreambooth and support_finetuning and support_controlnet:
|
||||
def validate_flex_dataset(dataset_config: dict):
|
||||
subsets_config = dataset_config.get("subsets", [])
|
||||
|
||||
if all(["conditioning_data_dir" in subset for subset in subsets_config]):
|
||||
return Schema(self.cn_dataset_schema)(dataset_config)
|
||||
# check dataset meets FT style
|
||||
# NOTE: all FT subsets should have "metadata_file"
|
||||
if all(["metadata_file" in subset for subset in subsets_config]):
|
||||
elif all(["metadata_file" in subset for subset in subsets_config]):
|
||||
return Schema(self.ft_dataset_schema)(dataset_config)
|
||||
# check dataset meets DB style
|
||||
# NOTE: all DB subsets should have no "metadata_file"
|
||||
@@ -241,13 +282,16 @@ class ConfigSanitizer:
|
||||
self.dataset_schema = validate_flex_dataset
|
||||
elif support_dreambooth:
|
||||
self.dataset_schema = self.db_dataset_schema
|
||||
else:
|
||||
elif support_finetuning:
|
||||
self.dataset_schema = self.ft_dataset_schema
|
||||
elif support_controlnet:
|
||||
self.dataset_schema = self.cn_dataset_schema
|
||||
|
||||
self.general_schema = self.__merge_dict(
|
||||
self.DATASET_ASCENDABLE_SCHEMA,
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
|
||||
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
|
||||
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
||||
)
|
||||
|
||||
@@ -318,7 +362,11 @@ class BlueprintGenerator:
|
||||
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
|
||||
subsets = dataset_config.get("subsets", [])
|
||||
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
|
||||
if is_dreambooth:
|
||||
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
|
||||
if is_controlnet:
|
||||
subset_params_klass = ControlNetSubsetParams
|
||||
dataset_params_klass = ControlNetDatasetParams
|
||||
elif is_dreambooth:
|
||||
subset_params_klass = DreamBoothSubsetParams
|
||||
dataset_params_klass = DreamBoothDatasetParams
|
||||
else:
|
||||
@@ -333,7 +381,7 @@ class BlueprintGenerator:
|
||||
|
||||
params = self.generate_params_by_fallbacks(dataset_params_klass,
|
||||
[dataset_config, general_config, argparse_config, runtime_params])
|
||||
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
|
||||
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
|
||||
|
||||
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
||||
|
||||
@@ -361,10 +409,13 @@ class BlueprintGenerator:
|
||||
|
||||
|
||||
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
||||
datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
|
||||
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
||||
|
||||
for dataset_blueprint in dataset_group_blueprint.datasets:
|
||||
if dataset_blueprint.is_dreambooth:
|
||||
if dataset_blueprint.is_controlnet:
|
||||
subset_klass = ControlNetSubset
|
||||
dataset_klass = ControlNetDataset
|
||||
elif dataset_blueprint.is_dreambooth:
|
||||
subset_klass = DreamBoothSubset
|
||||
dataset_klass = DreamBoothDataset
|
||||
else:
|
||||
@@ -379,6 +430,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
info = ""
|
||||
for i, dataset in enumerate(datasets):
|
||||
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
||||
is_controlnet = isinstance(dataset, ControlNetDataset)
|
||||
info += dedent(f"""\
|
||||
[Dataset {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
@@ -421,7 +473,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
class_tokens: {subset.class_tokens}
|
||||
caption_extension: {subset.caption_extension}
|
||||
\n"""), " ")
|
||||
else:
|
||||
elif not is_controlnet:
|
||||
info += indent(dedent(f"""\
|
||||
metadata_file: {subset.metadata_file}
|
||||
\n"""), " ")
|
||||
@@ -479,6 +531,31 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
|
||||
return subsets_config
|
||||
|
||||
|
||||
def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
|
||||
def generate(base_dir: Optional[str]):
|
||||
if base_dir is None:
|
||||
return []
|
||||
|
||||
base_dir: Path = Path(base_dir)
|
||||
if not base_dir.is_dir():
|
||||
return []
|
||||
|
||||
subsets_config = []
|
||||
for subdir in base_dir.iterdir():
|
||||
if not subdir.is_dir():
|
||||
continue
|
||||
|
||||
subset_config = {"image_dir": str(subdir), "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
|
||||
subsets_config.append(subset_config)
|
||||
|
||||
return subsets_config
|
||||
|
||||
subsets_config = []
|
||||
subsets_config += generate(train_data_dir, False)
|
||||
|
||||
return subsets_config
|
||||
|
||||
|
||||
def load_user_config(file: str) -> dict:
|
||||
file: Path = Path(file)
|
||||
if not file.is_file():
|
||||
|
||||
@@ -732,6 +732,82 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||
unet_conversion_map = [
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
|
||||
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
|
||||
]
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
("in_layers.0", "norm1"),
|
||||
("in_layers.2", "conv1"),
|
||||
("out_layers.0", "norm2"),
|
||||
("out_layers.3", "conv2"),
|
||||
("emb_layers.1", "time_emb_proj"),
|
||||
("skip_connection", "conv_shortcut"),
|
||||
]
|
||||
|
||||
unet_conversion_map_layer = []
|
||||
for i in range(4):
|
||||
for j in range(2):
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
controlnet_cond_embedding_names = (
|
||||
["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
||||
)
|
||||
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
||||
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
||||
sd_prefix = f"input_hint_block.{i*2}."
|
||||
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
||||
|
||||
for i in range(12):
|
||||
hf_prefix = f"controlnet_down_blocks.{i}."
|
||||
sd_prefix = f"zero_convs.{i}.0."
|
||||
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
||||
|
||||
mapping = {k: k for k in controlnet_state_dict.keys()}
|
||||
for sd_name, diffusers_name in unet_conversion_map:
|
||||
mapping[diffusers_name] = sd_name
|
||||
for k, v in mapping.items():
|
||||
if "resnets" in k:
|
||||
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
||||
v = v.replace(diffusers_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
for sd_part, diffusers_part in unet_conversion_map_layer:
|
||||
v = v.replace(diffusers_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
@@ -403,6 +403,54 @@ class FineTuningSubset(BaseSubset):
|
||||
return self.metadata_file == other.metadata_file
|
||||
|
||||
|
||||
class ControlNetSubset(BaseSubset):
|
||||
def __init__(
|
||||
self,
|
||||
image_dir: str,
|
||||
conditioning_data_dir: str,
|
||||
caption_extension: str,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
keep_tokens,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
keep_tokens,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
self.caption_extension = caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, ControlNetSubset):
|
||||
return NotImplemented
|
||||
return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir
|
||||
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
|
||||
@@ -1387,6 +1435,274 @@ class FineTuningDataset(BaseDataset):
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
class ControlNetDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[ControlNetSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
self.conditioning_image_data: Dict[str, ImageInfo] = {}
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.latents_cache = None
|
||||
|
||||
self.num_reg_images = 0
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
assert (
|
||||
min(resolution) >= min_bucket_reso
|
||||
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert (
|
||||
max(resolution) <= max_bucket_reso
|
||||
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path, caption_extension):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
tokens = base_name.split("_")
|
||||
if len(tokens) >= 5:
|
||||
base_name_face_det = "_".join(tokens[:-4])
|
||||
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
||||
|
||||
caption = None
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding="utf-8") as f:
|
||||
try:
|
||||
lines = f.readlines()
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
return caption
|
||||
|
||||
def load_controlnet_dir(subset: ControlNetSubset):
|
||||
if not os.path.isdir(subset.image_dir):
|
||||
print(f"not directory: {subset.image_dir}")
|
||||
return [], []
|
||||
if not os.path.isdir(subset.conditioning_data_dir):
|
||||
print(f"not directory: {subset.conditioning_data_dir}")
|
||||
return [], []
|
||||
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||
img_paths = sorted(img_paths)
|
||||
conditioning_img_paths = sorted(conditioning_img_paths)
|
||||
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
|
||||
|
||||
img_basenames = [os.path.basename(img) for img in img_paths]
|
||||
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
|
||||
missing_imgs = []
|
||||
extra_imgs = []
|
||||
|
||||
for img in img_basenames:
|
||||
if img not in conditioning_img_basenames:
|
||||
missing_imgs.append(img)
|
||||
for img in conditioning_img_basenames:
|
||||
if img not in img_basenames:
|
||||
extra_imgs.append(img)
|
||||
|
||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if cap_for_img is None:
|
||||
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
|
||||
captions.append("")
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
if missing_captions:
|
||||
number_of_missing_captions = len(missing_captions)
|
||||
number_of_missing_captions_to_show = 5
|
||||
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
|
||||
|
||||
print(
|
||||
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
|
||||
)
|
||||
for i, missing_caption in enumerate(missing_captions):
|
||||
if i >= number_of_missing_captions_to_show:
|
||||
print(missing_caption + f"... and {remaining_missing_captions} more")
|
||||
break
|
||||
print(missing_caption)
|
||||
return img_paths, conditioning_img_paths, captions
|
||||
|
||||
print("prepare images.")
|
||||
num_train_images = 0
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
print(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
|
||||
)
|
||||
continue
|
||||
|
||||
if subset in self.subsets:
|
||||
print(
|
||||
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
|
||||
)
|
||||
continue
|
||||
|
||||
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
|
||||
if len(img_paths) < 1:
|
||||
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
||||
continue
|
||||
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
|
||||
setattr(info, "cond_img_path", cond_img_path)
|
||||
self.register_image(info, subset)
|
||||
|
||||
subset.img_count = len(img_paths)
|
||||
self.subsets.append(subset)
|
||||
|
||||
print(f"{num_train_images} train images with repeating.")
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
self.conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
||||
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
||||
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
||||
|
||||
loss_weights = []
|
||||
captions = []
|
||||
input_ids_list = []
|
||||
latents_list = []
|
||||
images = []
|
||||
conditioning_images = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(1.0)
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
||||
latents = torch.FloatTensor(latents)
|
||||
image = None
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img = self.load_image(image_info.absolute_path)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
||||
else:
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
im_h == self.height and im_w == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
||||
if aug is not None:
|
||||
img = aug(image=img)["image"]
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
captions.append(caption)
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer)
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption)
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
|
||||
|
||||
cond_img = self.load_image(image_info.cond_img_path)
|
||||
if self.enable_bucket:
|
||||
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
|
||||
cond_img = self.conditioning_image_transforms(cond_img)
|
||||
conditioning_images.append(cond_img)
|
||||
conditioning_images = torch.stack(conditioning_images)
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
|
||||
if self.token_padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
else:
|
||||
# batch processing seems to be good
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
|
||||
if images[0] is not None:
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
else:
|
||||
images = None
|
||||
example["images"] = images
|
||||
|
||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||
example["captions"] = captions
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
|
||||
example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
return example
|
||||
|
||||
# behave as Dataset mock
|
||||
class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
||||
@@ -1636,6 +1952,8 @@ def get_git_revision_hash() -> str:
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
replace_attentions_for_hypernetwork()
|
||||
# unet is not used currently, but it is here for future use
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
if mem_eff_attn:
|
||||
unet.set_attn_processor(FlashAttnProcessor())
|
||||
elif xformers:
|
||||
|
||||
Reference in New Issue
Block a user