mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
update dataset to return size, refactor ctrlnet ds
This commit is contained in:
@@ -79,7 +79,7 @@ class ControlNetSubsetParams(BaseSubsetParams):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseDatasetParams:
|
class BaseDatasetParams:
|
||||||
tokenizer: CLIPTokenizer = None
|
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
||||||
max_token_length: int = None
|
max_token_length: int = None
|
||||||
resolution: Optional[Tuple[int, int]] = None
|
resolution: Optional[Tuple[int, int]] = None
|
||||||
debug_dataset: bool = False
|
debug_dataset: bool = False
|
||||||
|
|||||||
@@ -1116,13 +1116,15 @@ if __name__ == "__main__":
|
|||||||
# 使用メモリ量確認用の疑似学習ループ
|
# 使用メモリ量確認用の疑似学習ループ
|
||||||
print("preparing optimizer")
|
print("preparing optimizer")
|
||||||
|
|
||||||
import bitsandbytes
|
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
||||||
import transformers
|
|
||||||
|
|
||||||
|
# import bitsandbytes
|
||||||
# optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
|
# optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
|
||||||
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
||||||
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
||||||
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
|
||||||
|
import transformers
|
||||||
|
|
||||||
optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
|
optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||||
@@ -1133,7 +1135,7 @@ if __name__ == "__main__":
|
|||||||
for step in range(steps):
|
for step in range(steps):
|
||||||
print(f"step {step}")
|
print(f"step {step}")
|
||||||
|
|
||||||
x = torch.randn(batch_size, 4, 128, 128).cuda() # 512x512
|
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
|
||||||
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
|
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
|
||||||
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
||||||
y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
|
y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
|
||||||
|
|||||||
@@ -103,6 +103,9 @@ class ImageInfo:
|
|||||||
self.latents_flipped: torch.Tensor = None
|
self.latents_flipped: torch.Tensor = None
|
||||||
self.latents_npz: str = None
|
self.latents_npz: str = None
|
||||||
self.latents_npz_flipped: str = None
|
self.latents_npz_flipped: str = None
|
||||||
|
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
|
||||||
|
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
|
||||||
|
self.cond_img_path: str = None
|
||||||
|
|
||||||
|
|
||||||
class BucketManager:
|
class BucketManager:
|
||||||
@@ -171,6 +174,7 @@ class BucketManager:
|
|||||||
def select_bucket(self, image_width, image_height):
|
def select_bucket(self, image_width, image_height):
|
||||||
aspect_ratio = image_width / image_height
|
aspect_ratio = image_width / image_height
|
||||||
if not self.no_upscale:
|
if not self.no_upscale:
|
||||||
|
# 拡大および縮小を行う
|
||||||
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
||||||
reso = (image_width, image_height)
|
reso = (image_width, image_height)
|
||||||
if reso in self.predefined_resos_set:
|
if reso in self.predefined_resos_set:
|
||||||
@@ -189,6 +193,7 @@ class BucketManager:
|
|||||||
resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
|
resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
|
||||||
# print("use predef", image_width, image_height, reso, resized_size)
|
# print("use predef", image_width, image_height, reso, resized_size)
|
||||||
else:
|
else:
|
||||||
|
# 縮小のみを行う
|
||||||
if image_width * image_height > self.max_area:
|
if image_width * image_height > self.max_area:
|
||||||
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
||||||
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
||||||
@@ -238,41 +243,40 @@ class BucketBatchIndex(NamedTuple):
|
|||||||
class AugHelper:
|
class AugHelper:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# prepare all possible augmentators
|
# prepare all possible augmentators
|
||||||
color_aug_method = albu.OneOf(
|
self.color_aug_method = albu.OneOf(
|
||||||
[
|
[
|
||||||
albu.HueSaturationValue(8, 0, 0, p=0.5),
|
albu.HueSaturationValue(8, 0, 0, p=0.5),
|
||||||
albu.RandomGamma((95, 105), p=0.5),
|
albu.RandomGamma((95, 105), p=0.5),
|
||||||
],
|
],
|
||||||
p=0.33,
|
p=0.33,
|
||||||
)
|
)
|
||||||
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
|
||||||
|
|
||||||
# key: (use_color_aug, use_flip_aug)
|
# key: (use_color_aug, use_flip_aug)
|
||||||
self.augmentors = {
|
# self.augmentors = {
|
||||||
(True, True): albu.Compose(
|
# (True, True): albu.Compose(
|
||||||
[
|
# [
|
||||||
color_aug_method,
|
# color_aug_method,
|
||||||
flip_aug_method,
|
# flip_aug_method,
|
||||||
],
|
# ],
|
||||||
p=1.0,
|
# p=1.0,
|
||||||
),
|
# ),
|
||||||
(True, False): albu.Compose(
|
# (True, False): albu.Compose(
|
||||||
[
|
# [
|
||||||
color_aug_method,
|
# color_aug_method,
|
||||||
],
|
# ],
|
||||||
p=1.0,
|
# p=1.0,
|
||||||
),
|
# ),
|
||||||
(False, True): albu.Compose(
|
# (False, True): albu.Compose(
|
||||||
[
|
# [
|
||||||
flip_aug_method,
|
# flip_aug_method,
|
||||||
],
|
# ],
|
||||||
p=1.0,
|
# p=1.0,
|
||||||
),
|
# ),
|
||||||
(False, False): None,
|
# (False, False): None,
|
||||||
}
|
# }
|
||||||
|
|
||||||
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]:
|
||||||
return self.augmentors[(use_color_aug, use_flip_aug)]
|
return self.color_aug_method if use_color_aug else None
|
||||||
|
|
||||||
|
|
||||||
class BaseSubset:
|
class BaseSubset:
|
||||||
@@ -454,10 +458,16 @@ class ControlNetSubset(BaseSubset):
|
|||||||
|
|
||||||
class BaseDataset(torch.utils.data.Dataset):
|
class BaseDataset(torch.utils.data.Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
|
self,
|
||||||
|
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
|
||||||
|
max_token_length: int,
|
||||||
|
resolution: Optional[Tuple[int, int]],
|
||||||
|
debug_dataset: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
|
||||||
|
|
||||||
self.max_token_length = max_token_length
|
self.max_token_length = max_token_length
|
||||||
# width/height is used when enable_bucket==False
|
# width/height is used when enable_bucket==False
|
||||||
self.width, self.height = (None, None) if resolution is None else resolution
|
self.width, self.height = (None, None) if resolution is None else resolution
|
||||||
@@ -478,7 +488,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.bucket_no_upscale = None
|
self.bucket_no_upscale = None
|
||||||
self.bucket_info = None # for metadata
|
self.bucket_info = None # for metadata
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
|
||||||
|
|
||||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||||
|
|
||||||
@@ -594,48 +604,49 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
def get_input_ids(self, caption):
|
def get_input_ids(self, caption, tokenizer=None):
|
||||||
input_ids = self.tokenizer(
|
if tokenizer is None:
|
||||||
|
tokenizer = self.tokenizers[0]
|
||||||
|
|
||||||
|
input_ids = tokenizer(
|
||||||
caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt"
|
caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt"
|
||||||
).input_ids
|
).input_ids
|
||||||
|
|
||||||
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
if self.tokenizer_max_length > tokenizer.model_max_length:
|
||||||
input_ids = input_ids.squeeze(0)
|
input_ids = input_ids.squeeze(0)
|
||||||
iids_list = []
|
iids_list = []
|
||||||
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
if tokenizer.pad_token_id == tokenizer.eos_token_id:
|
||||||
# v1
|
# v1
|
||||||
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||||
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||||
for i in range(
|
for i in range(
|
||||||
1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2
|
1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2
|
||||||
): # (1, 152, 75)
|
): # (1, 152, 75)
|
||||||
ids_chunk = (
|
ids_chunk = (
|
||||||
input_ids[0].unsqueeze(0),
|
input_ids[0].unsqueeze(0),
|
||||||
input_ids[i : i + self.tokenizer.model_max_length - 2],
|
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||||
input_ids[-1].unsqueeze(0),
|
input_ids[-1].unsqueeze(0),
|
||||||
)
|
)
|
||||||
ids_chunk = torch.cat(ids_chunk)
|
ids_chunk = torch.cat(ids_chunk)
|
||||||
iids_list.append(ids_chunk)
|
iids_list.append(ids_chunk)
|
||||||
else:
|
else:
|
||||||
# v2
|
# v2 or SDXL
|
||||||
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
||||||
for i in range(
|
for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
||||||
1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2
|
|
||||||
):
|
|
||||||
ids_chunk = (
|
ids_chunk = (
|
||||||
input_ids[0].unsqueeze(0), # BOS
|
input_ids[0].unsqueeze(0), # BOS
|
||||||
input_ids[i : i + self.tokenizer.model_max_length - 2],
|
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||||
input_ids[-1].unsqueeze(0),
|
input_ids[-1].unsqueeze(0),
|
||||||
) # PAD or EOS
|
) # PAD or EOS
|
||||||
ids_chunk = torch.cat(ids_chunk)
|
ids_chunk = torch.cat(ids_chunk)
|
||||||
|
|
||||||
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
||||||
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
||||||
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
|
||||||
ids_chunk[-1] = self.tokenizer.eos_token_id
|
ids_chunk[-1] = tokenizer.eos_token_id
|
||||||
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
||||||
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
if ids_chunk[1] == tokenizer.pad_token_id:
|
||||||
ids_chunk[1] = self.tokenizer.eos_token_id
|
ids_chunk[1] = tokenizer.eos_token_id
|
||||||
|
|
||||||
iids_list.append(ids_chunk)
|
iids_list.append(ids_chunk)
|
||||||
|
|
||||||
@@ -755,46 +766,58 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
img = np.array(image, np.uint8)
|
img = np.array(image, np.uint8)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size, cond_img = None):
|
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
||||||
image_height, image_width = image.shape[0:2]
|
image_height, image_width = image.shape[0:2]
|
||||||
|
|
||||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||||
# リサイズする
|
# リサイズする
|
||||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||||
if exists(cond_img):
|
|
||||||
cond_img = cv2.resize(cond_img, resized_size, interpolation=cv2.INTER_AREA)
|
|
||||||
|
|
||||||
image_height, image_width = image.shape[0:2]
|
image_height, image_width = image.shape[0:2]
|
||||||
|
original_size = (image_width, image_height)
|
||||||
|
|
||||||
|
crop_left_top = (0, 0)
|
||||||
if image_width > reso[0]:
|
if image_width > reso[0]:
|
||||||
trim_size = image_width - reso[0]
|
trim_size = image_width - reso[0]
|
||||||
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
||||||
# print("w", trim_size, p)
|
# print("w", trim_size, p)
|
||||||
image = image[:, p : p + reso[0]]
|
image = image[:, p : p + reso[0]]
|
||||||
if exists(cond_img):
|
crop_left_top = (p, 0)
|
||||||
cond_img = cond_img[:, p : p + reso[0]]
|
|
||||||
if image_height > reso[1]:
|
if image_height > reso[1]:
|
||||||
trim_size = image_height - reso[1]
|
trim_size = image_height - reso[1]
|
||||||
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
||||||
# print("h", trim_size, p)
|
# print("h", trim_size, p)
|
||||||
image = image[p : p + reso[1]]
|
image = image[p : p + reso[1]]
|
||||||
if exists(cond_img):
|
crop_left_top = (crop_left_top[0], p)
|
||||||
cond_img = cond_img[p : p + reso[1]]
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
image.shape[0] == reso[1] and image.shape[1] == reso[0]
|
image.shape[0] == reso[1] and image.shape[1] == reso[0]
|
||||||
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||||
|
return image, original_size, crop_left_top
|
||||||
if exists(cond_img):
|
|
||||||
assert (
|
|
||||||
cond_img.shape[0] == reso[1] and cond_img.shape[1] == reso[0]
|
|
||||||
), f"internal error, illegal trimmed size: {cond_img.shape}, {reso}"
|
|
||||||
return image, cond_img
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
def is_latent_cacheable(self):
|
def is_latent_cacheable(self):
|
||||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||||
|
|
||||||
|
def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
|
||||||
|
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||||
|
|
||||||
|
for npath in [npz_path, flipped_npz_path]:
|
||||||
|
if npath is None:
|
||||||
|
continue
|
||||||
|
if not os.path.exists(npath):
|
||||||
|
return False
|
||||||
|
|
||||||
|
npz = np.load(npath)
|
||||||
|
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
|
||||||
|
return False
|
||||||
|
|
||||||
|
cached_latents = npz["latents"]
|
||||||
|
|
||||||
|
if cached_latents.shape[1:3] != expected_latents_size:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||||
# ちょっと速くした
|
# ちょっと速くした
|
||||||
print("caching latents.")
|
print("caching latents.")
|
||||||
@@ -811,38 +834,26 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
subset = self.image_to_subset[info.image_key]
|
subset = self.image_to_subset[info.image_key]
|
||||||
|
|
||||||
if info.latents_npz is not None:
|
if info.latents_npz is not None:
|
||||||
info.latents = self.load_latents_from_npz(info, False)
|
info.latents, info.latents_original_size, info.latents_crop_left_top = self.load_latents_from_npz(info, False)
|
||||||
info.latents = torch.FloatTensor(info.latents)
|
info.latents = torch.FloatTensor(info.latents)
|
||||||
|
|
||||||
# might be None, but that's ok because check is done in dataset
|
info.latents_flipped, _, _ = self.load_latents_from_npz(info, True) # might be None
|
||||||
info.latents_flipped = self.load_latents_from_npz(info, True)
|
|
||||||
if info.latents_flipped is not None:
|
if info.latents_flipped is not None:
|
||||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# check disk cache exists and size of latents
|
# check disk cache exists and size of latents
|
||||||
if cache_to_disk:
|
if cache_to_disk:
|
||||||
# TODO: refactor to unify with FineTuningDataset
|
|
||||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
||||||
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
|
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
|
||||||
if not is_main_process:
|
if not is_main_process:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cache_available = False
|
cache_available = self.is_disk_cached_latents_is_expected(
|
||||||
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意
|
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None
|
||||||
if os.path.exists(info.latents_npz):
|
)
|
||||||
cached_latents = np.load(info.latents_npz)["arr_0"]
|
|
||||||
if cached_latents.shape[1:3] == expected_latents_size:
|
|
||||||
cache_available = True
|
|
||||||
|
|
||||||
if subset.flip_aug:
|
if cache_available: # do not add to batch
|
||||||
cache_available = False
|
|
||||||
if os.path.exists(info.latents_npz_flipped):
|
|
||||||
cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"]
|
|
||||||
if cached_latents_flipped.shape[1:3] == expected_latents_size:
|
|
||||||
cache_available = True
|
|
||||||
|
|
||||||
if cache_available:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# if last member of batch has different resolution, flush the batch
|
# if last member of batch has different resolution, flush the batch
|
||||||
@@ -868,10 +879,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
images = []
|
images = []
|
||||||
for info in batch:
|
for info in batch:
|
||||||
image = self.load_image(info.absolute_path)
|
image = self.load_image(info.absolute_path)
|
||||||
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
image, original_size, crop_left_top = self.trim_and_resize_if_required(
|
||||||
|
subset, image, info.bucket_reso, info.resized_size
|
||||||
|
)
|
||||||
image = self.image_transforms(image)
|
image = self.image_transforms(image)
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
|
||||||
|
info.latents_original_size = original_size
|
||||||
|
info.latents_crop_left_top = crop_left_top
|
||||||
|
|
||||||
img_tensors = torch.stack(images, dim=0)
|
img_tensors = torch.stack(images, dim=0)
|
||||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||||
|
|
||||||
@@ -879,7 +895,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
for info, latent in zip(batch, latents):
|
for info, latent in zip(batch, latents):
|
||||||
if cache_to_disk:
|
if cache_to_disk:
|
||||||
np.savez(info.latents_npz, latent.float().numpy())
|
np.savez(
|
||||||
|
info.latents_npz,
|
||||||
|
latents=latent.float().numpy(),
|
||||||
|
original_size=np.array(info.latents_original_size),
|
||||||
|
crop_left_top=np.array(info.latents_crop_left_top),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
info.latents = latent
|
info.latents = latent
|
||||||
|
|
||||||
@@ -888,7 +909,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||||
for info, latent in zip(batch, latents):
|
for info, latent in zip(batch, latents):
|
||||||
if cache_to_disk:
|
if cache_to_disk:
|
||||||
np.savez(info.latents_npz_flipped, latent.float().numpy())
|
np.savez(
|
||||||
|
info.latents_npz_flipped,
|
||||||
|
latents=latent.float().numpy(),
|
||||||
|
original_size=np.array(info.latents_original_size),
|
||||||
|
crop_left_top=np.array(info.latents_crop_left_top), # reverse horizontally when use flipped latents
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
info.latents_flipped = latent
|
info.latents_flipped = latent
|
||||||
|
|
||||||
@@ -961,8 +987,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
||||||
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
||||||
if npz_file is None:
|
if npz_file is None:
|
||||||
return None
|
return None, None, None
|
||||||
return np.load(npz_file)["arr_0"]
|
|
||||||
|
npz = np.load(npz_file)
|
||||||
|
latents = npz["latents"]
|
||||||
|
original_size = npz["original_size"].tolist()
|
||||||
|
crop_left_top = npz["crop_left_top"].tolist()
|
||||||
|
return latents, original_size, crop_left_top
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._length
|
return self._length
|
||||||
@@ -975,21 +1006,35 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
loss_weights = []
|
loss_weights = []
|
||||||
captions = []
|
captions = []
|
||||||
input_ids_list = []
|
input_ids_list = []
|
||||||
|
input_ids2_list = []
|
||||||
latents_list = []
|
latents_list = []
|
||||||
images = []
|
images = []
|
||||||
|
original_sizes_hw = []
|
||||||
|
crop_top_lefts = []
|
||||||
|
target_sizes_hw = []
|
||||||
|
flippeds = [] # 変数名が微妙
|
||||||
|
|
||||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||||
image_info = self.image_data[image_key]
|
image_info = self.image_data[image_key]
|
||||||
subset = self.image_to_subset[image_key]
|
subset = self.image_to_subset[image_key]
|
||||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||||
|
|
||||||
|
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||||
|
|
||||||
# image/latentsを処理する
|
# image/latentsを処理する
|
||||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
original_size = image_info.latents_original_size
|
||||||
|
crop_left_top = image_info.latents_crop_left_top # calc values later if flipped
|
||||||
|
if not flipped:
|
||||||
|
latents = image_info.latents
|
||||||
|
else:
|
||||||
|
latents = image_info.latents_flipped
|
||||||
|
|
||||||
image = None
|
image = None
|
||||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
latents, original_size, crop_left_top = self.load_latents_from_npz(image_info, flipped)
|
||||||
latents = torch.FloatTensor(latents)
|
latents = torch.FloatTensor(latents)
|
||||||
|
|
||||||
image = None
|
image = None
|
||||||
else:
|
else:
|
||||||
# 画像を読み込み、必要ならcropする
|
# 画像を読み込み、必要ならcropする
|
||||||
@@ -997,7 +1042,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
im_h, im_w = img.shape[0:2]
|
im_h, im_w = img.shape[0:2]
|
||||||
|
|
||||||
if self.enable_bucket:
|
if self.enable_bucket:
|
||||||
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
img, original_size, crop_left_top = self.trim_and_resize_if_required(
|
||||||
|
subset, img, image_info.bucket_reso, image_info.resized_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if face_cx > 0: # 顔位置情報あり
|
if face_cx > 0: # 顔位置情報あり
|
||||||
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
||||||
@@ -1017,17 +1064,33 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
im_h == self.height and im_w == self.width
|
im_h == self.height and im_w == self.width
|
||||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||||
|
|
||||||
|
original_size = [im_w, im_h]
|
||||||
|
crop_left_top = [0, 0]
|
||||||
|
|
||||||
# augmentation
|
# augmentation
|
||||||
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
aug = self.aug_helper.get_augmentor(subset.color_aug)
|
||||||
if aug is not None:
|
if aug is not None:
|
||||||
img = aug(image=img)["image"]
|
img = aug(image=img)["image"]
|
||||||
|
|
||||||
|
if flipped:
|
||||||
|
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
|
||||||
|
|
||||||
latents = None
|
latents = None
|
||||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||||
|
|
||||||
images.append(image)
|
images.append(image)
|
||||||
latents_list.append(latents)
|
latents_list.append(latents)
|
||||||
|
|
||||||
|
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||||
|
|
||||||
|
if flipped:
|
||||||
|
crop_left_top = (original_size[0] - crop_left_top[0] - target_size[0], crop_left_top[1])
|
||||||
|
|
||||||
|
original_sizes_hw.append((original_size[1], original_size[0]))
|
||||||
|
crop_top_lefts.append((crop_left_top[1], crop_left_top[0]))
|
||||||
|
target_sizes_hw.append((target_size[1], target_size[0]))
|
||||||
|
flippeds.append(flipped)
|
||||||
|
|
||||||
caption = self.process_caption(subset, image_info.caption)
|
caption = self.process_caption(subset, image_info.caption)
|
||||||
if self.XTI_layers:
|
if self.XTI_layers:
|
||||||
caption_layer = []
|
caption_layer = []
|
||||||
@@ -1039,22 +1102,33 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
captions.append(caption_layer)
|
captions.append(caption_layer)
|
||||||
else:
|
else:
|
||||||
captions.append(caption)
|
captions.append(caption)
|
||||||
|
|
||||||
if not self.token_padding_disabled: # this option might be omitted in future
|
if not self.token_padding_disabled: # this option might be omitted in future
|
||||||
if self.XTI_layers:
|
if self.XTI_layers:
|
||||||
token_caption = self.get_input_ids(caption_layer)
|
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||||
else:
|
else:
|
||||||
token_caption = self.get_input_ids(caption)
|
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||||
input_ids_list.append(token_caption)
|
input_ids_list.append(token_caption)
|
||||||
|
|
||||||
|
if len(self.tokenizers) > 1:
|
||||||
|
if self.XTI_layers:
|
||||||
|
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||||
|
else:
|
||||||
|
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||||
|
input_ids2_list.append(token_caption2)
|
||||||
|
|
||||||
example = {}
|
example = {}
|
||||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||||
|
|
||||||
if self.token_padding_disabled:
|
if self.token_padding_disabled:
|
||||||
# padding=True means pad in the batch
|
# padding=True means pad in the batch
|
||||||
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||||
|
if len(self.tokenizers) > 1:
|
||||||
|
# following may not work in SDXL, keep the line for future update
|
||||||
|
example["input_ids2"] = self.tokenizer[1](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||||
else:
|
else:
|
||||||
# batch processing seems to be good
|
|
||||||
example["input_ids"] = torch.stack(input_ids_list)
|
example["input_ids"] = torch.stack(input_ids_list)
|
||||||
|
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
|
||||||
|
|
||||||
if images[0] is not None:
|
if images[0] is not None:
|
||||||
images = torch.stack(images)
|
images = torch.stack(images)
|
||||||
@@ -1066,6 +1140,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||||
example["captions"] = captions
|
example["captions"] = captions
|
||||||
|
|
||||||
|
example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw])
|
||||||
|
example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts])
|
||||||
|
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
|
||||||
|
example["flippeds"] = flippeds
|
||||||
|
|
||||||
if self.debug_dataset:
|
if self.debug_dataset:
|
||||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||||
return example
|
return example
|
||||||
@@ -1462,151 +1541,86 @@ class ControlNetDataset(BaseDataset):
|
|||||||
max_bucket_reso: int,
|
max_bucket_reso: int,
|
||||||
bucket_reso_steps: int,
|
bucket_reso_steps: int,
|
||||||
bucket_no_upscale: bool,
|
bucket_no_upscale: bool,
|
||||||
debug_dataset) -> None:
|
debug_dataset,
|
||||||
|
) -> None:
|
||||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||||
self.conditioning_image_data: Dict[str, ImageInfo] = {}
|
|
||||||
|
|
||||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
db_subsets = []
|
||||||
|
for subset in subsets:
|
||||||
|
db_subset = DreamBoothSubset(
|
||||||
|
subset.image_dir,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
subset.caption_extension,
|
||||||
|
subset.num_repeats,
|
||||||
|
subset.shuffle_caption,
|
||||||
|
subset.keep_tokens,
|
||||||
|
subset.color_aug,
|
||||||
|
subset.flip_aug,
|
||||||
|
subset.face_crop_aug_range,
|
||||||
|
subset.random_crop,
|
||||||
|
subset.caption_dropout_rate,
|
||||||
|
subset.caption_dropout_every_n_epochs,
|
||||||
|
subset.caption_tag_dropout_rate,
|
||||||
|
subset.token_warmup_min,
|
||||||
|
subset.token_warmup_step,
|
||||||
|
)
|
||||||
|
db_subsets.append(db_subset)
|
||||||
|
|
||||||
|
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
||||||
|
db_subsets,
|
||||||
|
batch_size,
|
||||||
|
tokenizer,
|
||||||
|
max_token_length,
|
||||||
|
resolution,
|
||||||
|
enable_bucket,
|
||||||
|
min_bucket_reso,
|
||||||
|
max_bucket_reso,
|
||||||
|
bucket_reso_steps,
|
||||||
|
bucket_no_upscale,
|
||||||
|
1.0,
|
||||||
|
debug_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||||
|
self.image_data = self.dreambooth_dataset_delegate.image_data
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.size = min(self.width, self.height) # 短いほう
|
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||||
self.latents_cache = None
|
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||||
|
|
||||||
self.num_reg_images = 0
|
# assert all conditioning data exists
|
||||||
|
missing_imgs = []
|
||||||
self.enable_bucket = enable_bucket
|
cond_imgs_with_img = set()
|
||||||
if self.enable_bucket:
|
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
|
||||||
assert (
|
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
|
||||||
min(resolution) >= min_bucket_reso
|
subset = None
|
||||||
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
for s in subsets:
|
||||||
assert (
|
if s.image_dir == db_subset.image_dir:
|
||||||
max(resolution) <= max_bucket_reso
|
subset = s
|
||||||
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
|
||||||
self.min_bucket_reso = min_bucket_reso
|
|
||||||
self.max_bucket_reso = max_bucket_reso
|
|
||||||
self.bucket_reso_steps = bucket_reso_steps
|
|
||||||
self.bucket_no_upscale = bucket_no_upscale
|
|
||||||
else:
|
|
||||||
self.min_bucket_reso = None
|
|
||||||
self.max_bucket_reso = None
|
|
||||||
self.bucket_reso_steps = None # この情報は使われない
|
|
||||||
self.bucket_no_upscale = False
|
|
||||||
|
|
||||||
def read_caption(img_path, caption_extension):
|
|
||||||
# captionの候補ファイル名を作る
|
|
||||||
base_name = os.path.splitext(img_path)[0]
|
|
||||||
base_name_face_det = base_name
|
|
||||||
tokens = base_name.split("_")
|
|
||||||
if len(tokens) >= 5:
|
|
||||||
base_name_face_det = "_".join(tokens[:-4])
|
|
||||||
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
|
||||||
|
|
||||||
caption = None
|
|
||||||
for cap_path in cap_paths:
|
|
||||||
if os.path.isfile(cap_path):
|
|
||||||
with open(cap_path, "rt", encoding="utf-8") as f:
|
|
||||||
try:
|
|
||||||
lines = f.readlines()
|
|
||||||
except UnicodeDecodeError as e:
|
|
||||||
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
|
||||||
raise e
|
|
||||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
|
||||||
caption = lines[0].strip()
|
|
||||||
break
|
break
|
||||||
return caption
|
assert subset is not None, "internal error: subset not found"
|
||||||
|
|
||||||
def load_controlnet_dir(subset: ControlNetSubset):
|
|
||||||
if not os.path.isdir(subset.image_dir):
|
|
||||||
print(f"not directory: {subset.image_dir}")
|
|
||||||
return [], []
|
|
||||||
if not os.path.isdir(subset.conditioning_data_dir):
|
if not os.path.isdir(subset.conditioning_data_dir):
|
||||||
print(f"not directory: {subset.conditioning_data_dir}")
|
print(f"not directory: {subset.conditioning_data_dir}")
|
||||||
return [], []
|
continue
|
||||||
|
|
||||||
img_paths = glob_images(subset.image_dir, "*")
|
img_basename = os.path.basename(info.absolute_path)
|
||||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
|
||||||
img_paths = sorted(img_paths)
|
if not os.path.exists(ctrl_img_path):
|
||||||
conditioning_img_paths = sorted(conditioning_img_paths)
|
missing_imgs.append(img_basename)
|
||||||
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
|
||||||
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
|
|
||||||
|
|
||||||
img_basenames = [os.path.basename(img) for img in img_paths]
|
info.cond_img_path = ctrl_img_path
|
||||||
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
|
cond_imgs_with_img.add(ctrl_img_path)
|
||||||
missing_imgs = []
|
|
||||||
extra_imgs = []
|
|
||||||
|
|
||||||
for img in img_basenames:
|
extra_imgs = []
|
||||||
if img not in conditioning_img_basenames:
|
|
||||||
missing_imgs.append(img)
|
|
||||||
for img in conditioning_img_basenames:
|
|
||||||
if img not in img_basenames:
|
|
||||||
extra_imgs.append(img)
|
|
||||||
|
|
||||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
|
||||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
|
||||||
|
|
||||||
|
|
||||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
|
||||||
captions = []
|
|
||||||
missing_captions = []
|
|
||||||
for img_path in img_paths:
|
|
||||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
|
||||||
if cap_for_img is None:
|
|
||||||
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
|
|
||||||
captions.append("")
|
|
||||||
missing_captions.append(img_path)
|
|
||||||
else:
|
|
||||||
captions.append(cap_for_img)
|
|
||||||
|
|
||||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
|
||||||
|
|
||||||
if missing_captions:
|
|
||||||
number_of_missing_captions = len(missing_captions)
|
|
||||||
number_of_missing_captions_to_show = 5
|
|
||||||
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
|
|
||||||
)
|
|
||||||
for i, missing_caption in enumerate(missing_captions):
|
|
||||||
if i >= number_of_missing_captions_to_show:
|
|
||||||
print(missing_caption + f"... and {remaining_missing_captions} more")
|
|
||||||
break
|
|
||||||
print(missing_caption)
|
|
||||||
return img_paths, conditioning_img_paths, captions
|
|
||||||
|
|
||||||
print("prepare images.")
|
|
||||||
num_train_images = 0
|
|
||||||
for subset in subsets:
|
for subset in subsets:
|
||||||
if subset.num_repeats < 1:
|
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||||
print(
|
extra_imgs.extend(
|
||||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
|
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
if subset in self.subsets:
|
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||||
print(
|
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||||
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
|
|
||||||
if len(img_paths) < 1:
|
|
||||||
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
|
||||||
continue
|
|
||||||
|
|
||||||
num_train_images += subset.num_repeats * len(img_paths)
|
|
||||||
|
|
||||||
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
|
|
||||||
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
|
|
||||||
setattr(info, "cond_img_path", cond_img_path)
|
|
||||||
self.register_image(info, subset)
|
|
||||||
|
|
||||||
subset.img_count = len(img_paths)
|
|
||||||
self.subsets.append(subset)
|
|
||||||
|
|
||||||
print(f"{num_train_images} train images with repeating.")
|
|
||||||
self.num_train_images = num_train_images
|
|
||||||
|
|
||||||
self.conditioning_image_transforms = transforms.Compose(
|
self.conditioning_image_transforms = transforms.Compose(
|
||||||
[
|
[
|
||||||
@@ -1614,88 +1628,58 @@ class ControlNetDataset(BaseDataset):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def make_buckets(self):
|
||||||
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
self.dreambooth_dataset_delegate.make_buckets()
|
||||||
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||||
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.dreambooth_dataset_delegate.__len__()
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
example = self.dreambooth_dataset_delegate[index]
|
||||||
|
|
||||||
|
bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[
|
||||||
|
self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index
|
||||||
|
]
|
||||||
|
bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size
|
||||||
|
image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size
|
||||||
|
|
||||||
loss_weights = []
|
|
||||||
captions = []
|
|
||||||
input_ids_list = []
|
|
||||||
latents_list = []
|
|
||||||
images = []
|
|
||||||
conditioning_images = []
|
conditioning_images = []
|
||||||
|
|
||||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]):
|
||||||
image_info = self.image_data[image_key]
|
image_info = self.dreambooth_dataset_delegate.image_data[image_key]
|
||||||
subset = self.image_to_subset[image_key]
|
|
||||||
loss_weights.append(1.0)
|
|
||||||
|
|
||||||
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
|
target_size_hw = example["target_sizes_hw"][i]
|
||||||
|
original_size_hw = example["original_sizes_hw"][i]
|
||||||
|
crop_top_left = example["crop_top_lefts"][i]
|
||||||
|
flipped = example["flippeds"][i]
|
||||||
|
cond_img = self.load_image(image_info.cond_img_path)
|
||||||
|
|
||||||
# image/latentsを処理する
|
if self.dreambooth_dataset_delegate.enable_bucket:
|
||||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
assert (
|
||||||
image = None
|
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
ct, cl = crop_top_left
|
||||||
latents = torch.FloatTensor(latents)
|
h, w = target_size_hw
|
||||||
image = None
|
cond_img = cond_img[ct : ct + h, cl : cl + w]
|
||||||
else:
|
else:
|
||||||
# 画像を読み込み、必要ならcropする
|
assert (
|
||||||
img = self.load_image(image_info.absolute_path)
|
cond_img.shape[0] == self.height and cond_img.shape[1] == self.width
|
||||||
cond_img = self.load_image(image_info.cond_img_path)
|
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||||
im_h, im_w = img.shape[0:2]
|
|
||||||
|
|
||||||
if self.enable_bucket:
|
if flipped:
|
||||||
img, cond_img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size, cond_img=cond_img)
|
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||||
else:
|
|
||||||
im_h, im_w = img.shape[0:2]
|
|
||||||
assert (
|
|
||||||
im_h == self.height and im_w == self.width
|
|
||||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
|
||||||
|
|
||||||
# augmentation
|
|
||||||
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
|
||||||
if aug is not None:
|
|
||||||
img = aug(image=img)["image"]
|
|
||||||
|
|
||||||
latents = None
|
|
||||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
|
||||||
|
|
||||||
images.append(image)
|
|
||||||
latents_list.append(latents)
|
|
||||||
|
|
||||||
cond_img = self.conditioning_image_transforms(cond_img)
|
cond_img = self.conditioning_image_transforms(cond_img)
|
||||||
conditioning_images.append(cond_img)
|
conditioning_images.append(cond_img)
|
||||||
|
|
||||||
caption = self.process_caption(subset, image_info.caption)
|
|
||||||
captions.append(caption)
|
|
||||||
token_caption = self.get_input_ids(caption)
|
|
||||||
input_ids_list.append(token_caption)
|
|
||||||
|
|
||||||
example = {}
|
|
||||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
|
||||||
|
|
||||||
example["input_ids"] = torch.stack(input_ids_list)
|
|
||||||
|
|
||||||
if images[0] is not None:
|
|
||||||
images = torch.stack(images)
|
|
||||||
images = images.to(memory_format=torch.contiguous_format).float()
|
|
||||||
else:
|
|
||||||
images = None
|
|
||||||
example["images"] = images
|
|
||||||
|
|
||||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
|
||||||
example["captions"] = captions
|
|
||||||
|
|
||||||
if self.debug_dataset:
|
|
||||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
|
||||||
|
|
||||||
example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
|
example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
|
||||||
|
|
||||||
return example
|
return example
|
||||||
|
|
||||||
|
|
||||||
# behave as Dataset mock
|
# behave as Dataset mock
|
||||||
class DatasetGroup(torch.utils.data.ConcatDataset):
|
class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||||
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
||||||
@@ -1773,18 +1757,42 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
example = train_dataset[idx]
|
example = train_dataset[idx]
|
||||||
if example["latents"] is not None:
|
if example["latents"] is not None:
|
||||||
print(f"sample has latents from npz file: {example['latents'].size()}")
|
print(f"sample has latents from npz file: {example['latents'].size()}")
|
||||||
for j, (ik, cap, lw, iid) in enumerate(
|
for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
|
||||||
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
|
zip(
|
||||||
|
example["image_keys"],
|
||||||
|
example["captions"],
|
||||||
|
example["loss_weights"],
|
||||||
|
example["input_ids"],
|
||||||
|
example["original_sizes_hw"],
|
||||||
|
example["crop_top_lefts"],
|
||||||
|
example["target_sizes_hw"],
|
||||||
|
example["flippeds"],
|
||||||
|
)
|
||||||
):
|
):
|
||||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
print(
|
||||||
|
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop left top: {crptl}, target size: {trgsz}, flipped: {flpdz}'
|
||||||
|
)
|
||||||
|
|
||||||
if show_input_ids:
|
if show_input_ids:
|
||||||
print(f"input ids: {iid}")
|
print(f"input ids: {iid}")
|
||||||
|
if "input_ids2" in example:
|
||||||
|
print(f"input ids2: {example['input_ids2'][j]}")
|
||||||
if example["images"] is not None:
|
if example["images"] is not None:
|
||||||
im = example["images"][j]
|
im = example["images"][j]
|
||||||
print(f"image size: {im.size()}")
|
print(f"image size: {im.size()}")
|
||||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||||
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
||||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||||
|
|
||||||
|
if "conditioning_images" in example:
|
||||||
|
cond_img = example["conditioning_images"][j]
|
||||||
|
print(f"conditioning image size: {cond_img.size()}")
|
||||||
|
cond_img = (cond_img.numpy() * 255.0).astype(np.uint8)
|
||||||
|
cond_img = np.transpose(cond_img, (1, 2, 0))
|
||||||
|
cond_img = cond_img[:, :, ::-1]
|
||||||
|
if os.name == "nt":
|
||||||
|
cv2.imshow("cond_img", cond_img)
|
||||||
|
|
||||||
if os.name == "nt": # only windows
|
if os.name == "nt": # only windows
|
||||||
cv2.imshow("img", im)
|
cv2.imshow("img", im)
|
||||||
k = cv2.waitKey()
|
k = cv2.waitKey()
|
||||||
@@ -2011,7 +2019,6 @@ def get_git_revision_hash() -> str:
|
|||||||
return "(unknown)"
|
return "(unknown)"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||||
# replace_attentions_for_hypernetwork()
|
# replace_attentions_for_hypernetwork()
|
||||||
# # unet is not used currently, but it is here for future use
|
# # unet is not used currently, but it is here for future use
|
||||||
@@ -2063,8 +2070,9 @@ def get_git_revision_hash() -> str:
|
|||||||
# out = self.to_out[1](out)
|
# out = self.to_out[1](out)
|
||||||
# return out
|
# return out
|
||||||
|
|
||||||
|
|
||||||
# diffusers.models.attention.CrossAttention.forward = forward_xformers
|
# diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||||
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||||
if mem_eff_attn:
|
if mem_eff_attn:
|
||||||
print("Enable memory efficient attention for U-Net")
|
print("Enable memory efficient attention for U-Net")
|
||||||
unet.set_use_memory_efficient_attention(False, True)
|
unet.set_use_memory_efficient_attention(False, True)
|
||||||
@@ -2080,6 +2088,7 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa
|
|||||||
print("Enable SDPA for U-Net")
|
print("Enable SDPA for U-Net")
|
||||||
unet.set_use_sdpa(True)
|
unet.set_use_sdpa(True)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||||
# vae is not used currently, but it is here for future use
|
# vae is not used currently, but it is here for future use
|
||||||
@@ -2327,7 +2336,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
||||||
)
|
)
|
||||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||||
parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)")
|
parser.add_argument(
|
||||||
|
"--sdpa",
|
||||||
|
action="store_true",
|
||||||
|
help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
|
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
|
||||||
)
|
)
|
||||||
@@ -3231,7 +3244,9 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
|||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
|
||||||
|
args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Diffusers model is loaded to CPU
|
# Diffusers model is loaded to CPU
|
||||||
print(f"load Diffusers pretrained models: {name_or_path}")
|
print(f"load Diffusers pretrained models: {name_or_path}")
|
||||||
@@ -3281,7 +3296,10 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
|||||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||||
|
|
||||||
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
||||||
args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
|
args,
|
||||||
|
weight_dtype,
|
||||||
|
accelerator.device if args.lowram else "cpu",
|
||||||
|
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# work on low-ram device
|
# work on low-ram device
|
||||||
@@ -3595,7 +3613,17 @@ SCHEDLER_SCHEDULE = "scaled_linear"
|
|||||||
|
|
||||||
|
|
||||||
def sample_images(
|
def sample_images(
|
||||||
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None
|
accelerator,
|
||||||
|
args: argparse.Namespace,
|
||||||
|
epoch,
|
||||||
|
steps,
|
||||||
|
device,
|
||||||
|
vae,
|
||||||
|
tokenizer,
|
||||||
|
text_encoder,
|
||||||
|
unet,
|
||||||
|
prompt_replacement=None,
|
||||||
|
controlnet=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||||
@@ -3690,7 +3718,7 @@ def sample_images(
|
|||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
clip_skip=args.clip_skip,
|
clip_skip=args.clip_skip,
|
||||||
)
|
)
|
||||||
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
|
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
|
||||||
pipeline.to(device)
|
pipeline.to(device)
|
||||||
|
|
||||||
save_dir = args.output_dir + "/sample"
|
save_dir = args.output_dir + "/sample"
|
||||||
@@ -3765,7 +3793,6 @@ def sample_images(
|
|||||||
controlnet_image = m.group(1)
|
controlnet_image = m.group(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
print(ex)
|
||||||
|
|||||||
Reference in New Issue
Block a user