add controlnet training

This commit is contained in:
ddPn08
2023-05-31 14:13:15 +09:00
parent 4f8ce00477
commit 62d00b4520
4 changed files with 1075 additions and 10 deletions

View File

@@ -403,6 +403,54 @@ class FineTuningSubset(BaseSubset):
return self.metadata_file == other.metadata_file
class ControlNetSubset(BaseSubset):
def __init__(
self,
image_dir: str,
conditioning_data_dir: str,
caption_extension: str,
num_repeats,
shuffle_caption,
keep_tokens,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
super().__init__(
image_dir,
num_repeats,
shuffle_caption,
keep_tokens,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.conditioning_data_dir = conditioning_data_dir
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
def __eq__(self, other) -> bool:
if not isinstance(other, ControlNetSubset):
return NotImplemented
return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir
class BaseDataset(torch.utils.data.Dataset):
def __init__(
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
@@ -1387,6 +1435,274 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
class ControlNetDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[ControlNetSubset],
batch_size: int,
tokenizer,
max_token_length,
resolution,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
self.conditioning_image_data: Dict[str, ImageInfo] = {}
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.latents_cache = None
self.num_reg_images = 0
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
break
return caption
def load_controlnet_dir(subset: ControlNetSubset):
if not os.path.isdir(subset.image_dir):
print(f"not directory: {subset.image_dir}")
return [], []
if not os.path.isdir(subset.conditioning_data_dir):
print(f"not directory: {subset.conditioning_data_dir}")
return [], []
img_paths = glob_images(subset.image_dir, "*")
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
img_paths = sorted(img_paths)
conditioning_img_paths = sorted(conditioning_img_paths)
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
img_basenames = [os.path.basename(img) for img in img_paths]
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
missing_imgs = []
extra_imgs = []
for img in img_basenames:
if img not in conditioning_img_basenames:
missing_imgs.append(img)
for img in conditioning_img_basenames:
if img not in img_basenames:
extra_imgs.append(img)
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
captions.append("")
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
print(
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
)
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(missing_caption + f"... and {remaining_missing_captions} more")
break
print(missing_caption)
return img_paths, conditioning_img_paths, captions
print("prepare images.")
num_train_images = 0
for subset in subsets:
if subset.num_repeats < 1:
print(
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
)
continue
if subset in self.subsets:
print(
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
)
continue
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
if len(img_paths) < 1:
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
continue
num_train_images += subset.num_repeats * len(img_paths)
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
setattr(info, "cond_img_path", cond_img_path)
self.register_image(info, subset)
subset.img_count = len(img_paths)
self.subsets.append(subset)
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
self.conditioning_image_transforms = transforms.Compose(
[
transforms.ToTensor(),
]
)
def __getitem__(self, index):
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
loss_weights = []
captions = []
input_ids_list = []
latents_list = []
images = []
conditioning_images = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(1.0)
# image/latentsを処理する
if image_info.latents is not None: # cache_latents=Trueの場合
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
latents = torch.FloatTensor(latents)
image = None
else:
# 画像を読み込み、必要ならcropする
img = self.load_image(image_info.absolute_path)
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
else:
im_h, im_w = img.shape[0:2]
assert (
im_h == self.height and im_w == self.width
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
if aug is not None:
img = aug(image=img)["image"]
latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
images.append(image)
latents_list.append(latents)
caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer)
else:
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
cond_img = self.load_image(image_info.cond_img_path)
if self.enable_bucket:
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
cond_img = self.conditioning_image_transforms(cond_img)
conditioning_images.append(cond_img)
conditioning_images = torch.stack(conditioning_images)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
if self.token_padding_disabled:
# padding=True means pad in the batch
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
else:
# batch processing seems to be good
example["input_ids"] = torch.stack(input_ids_list)
if images[0] is not None:
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
else:
images = None
example["images"] = images
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
return example
# behave as Dataset mock
class DatasetGroup(torch.utils.data.ConcatDataset):
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
@@ -1636,6 +1952,8 @@ def get_git_revision_hash() -> str:
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
replace_attentions_for_hypernetwork()
# unet is not used currently, but it is here for future use
unet.enable_xformers_memory_efficient_attention()
return
if mem_eff_attn:
unet.set_attn_processor(FlashAttnProcessor())
elif xformers: