mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add controlnet training
This commit is contained in:
@@ -403,6 +403,54 @@ class FineTuningSubset(BaseSubset):
|
||||
return self.metadata_file == other.metadata_file
|
||||
|
||||
|
||||
class ControlNetSubset(BaseSubset):
|
||||
def __init__(
|
||||
self,
|
||||
image_dir: str,
|
||||
conditioning_data_dir: str,
|
||||
caption_extension: str,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
keep_tokens,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
keep_tokens,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
self.caption_extension = caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, ControlNetSubset):
|
||||
return NotImplemented
|
||||
return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir
|
||||
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
|
||||
@@ -1387,6 +1435,274 @@ class FineTuningDataset(BaseDataset):
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
class ControlNetDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[ControlNetSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
self.conditioning_image_data: Dict[str, ImageInfo] = {}
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.latents_cache = None
|
||||
|
||||
self.num_reg_images = 0
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
assert (
|
||||
min(resolution) >= min_bucket_reso
|
||||
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert (
|
||||
max(resolution) <= max_bucket_reso
|
||||
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path, caption_extension):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
tokens = base_name.split("_")
|
||||
if len(tokens) >= 5:
|
||||
base_name_face_det = "_".join(tokens[:-4])
|
||||
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
||||
|
||||
caption = None
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding="utf-8") as f:
|
||||
try:
|
||||
lines = f.readlines()
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
return caption
|
||||
|
||||
def load_controlnet_dir(subset: ControlNetSubset):
|
||||
if not os.path.isdir(subset.image_dir):
|
||||
print(f"not directory: {subset.image_dir}")
|
||||
return [], []
|
||||
if not os.path.isdir(subset.conditioning_data_dir):
|
||||
print(f"not directory: {subset.conditioning_data_dir}")
|
||||
return [], []
|
||||
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||
img_paths = sorted(img_paths)
|
||||
conditioning_img_paths = sorted(conditioning_img_paths)
|
||||
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
|
||||
|
||||
img_basenames = [os.path.basename(img) for img in img_paths]
|
||||
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
|
||||
missing_imgs = []
|
||||
extra_imgs = []
|
||||
|
||||
for img in img_basenames:
|
||||
if img not in conditioning_img_basenames:
|
||||
missing_imgs.append(img)
|
||||
for img in conditioning_img_basenames:
|
||||
if img not in img_basenames:
|
||||
extra_imgs.append(img)
|
||||
|
||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if cap_for_img is None:
|
||||
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
|
||||
captions.append("")
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
if missing_captions:
|
||||
number_of_missing_captions = len(missing_captions)
|
||||
number_of_missing_captions_to_show = 5
|
||||
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
|
||||
|
||||
print(
|
||||
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
|
||||
)
|
||||
for i, missing_caption in enumerate(missing_captions):
|
||||
if i >= number_of_missing_captions_to_show:
|
||||
print(missing_caption + f"... and {remaining_missing_captions} more")
|
||||
break
|
||||
print(missing_caption)
|
||||
return img_paths, conditioning_img_paths, captions
|
||||
|
||||
print("prepare images.")
|
||||
num_train_images = 0
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
print(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
|
||||
)
|
||||
continue
|
||||
|
||||
if subset in self.subsets:
|
||||
print(
|
||||
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
|
||||
)
|
||||
continue
|
||||
|
||||
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
|
||||
if len(img_paths) < 1:
|
||||
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
||||
continue
|
||||
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
|
||||
setattr(info, "cond_img_path", cond_img_path)
|
||||
self.register_image(info, subset)
|
||||
|
||||
subset.img_count = len(img_paths)
|
||||
self.subsets.append(subset)
|
||||
|
||||
print(f"{num_train_images} train images with repeating.")
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
self.conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
||||
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
||||
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
||||
|
||||
loss_weights = []
|
||||
captions = []
|
||||
input_ids_list = []
|
||||
latents_list = []
|
||||
images = []
|
||||
conditioning_images = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(1.0)
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
||||
latents = torch.FloatTensor(latents)
|
||||
image = None
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img = self.load_image(image_info.absolute_path)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
||||
else:
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
im_h == self.height and im_w == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
||||
if aug is not None:
|
||||
img = aug(image=img)["image"]
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
captions.append(caption)
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer)
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption)
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
|
||||
|
||||
cond_img = self.load_image(image_info.cond_img_path)
|
||||
if self.enable_bucket:
|
||||
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
|
||||
cond_img = self.conditioning_image_transforms(cond_img)
|
||||
conditioning_images.append(cond_img)
|
||||
conditioning_images = torch.stack(conditioning_images)
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
|
||||
if self.token_padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
else:
|
||||
# batch processing seems to be good
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
|
||||
if images[0] is not None:
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
else:
|
||||
images = None
|
||||
example["images"] = images
|
||||
|
||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||
example["captions"] = captions
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
|
||||
example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
return example
|
||||
|
||||
# behave as Dataset mock
|
||||
class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
||||
@@ -1636,6 +1952,8 @@ def get_git_revision_hash() -> str:
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
replace_attentions_for_hypernetwork()
|
||||
# unet is not used currently, but it is here for future use
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
if mem_eff_attn:
|
||||
unet.set_attn_processor(FlashAttnProcessor())
|
||||
elif xformers:
|
||||
|
||||
Reference in New Issue
Block a user