mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
update dataset to return size, refactor ctrlnet ds
This commit is contained in:
@@ -103,6 +103,9 @@ class ImageInfo:
|
||||
self.latents_flipped: torch.Tensor = None
|
||||
self.latents_npz: str = None
|
||||
self.latents_npz_flipped: str = None
|
||||
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
|
||||
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
|
||||
self.cond_img_path: str = None
|
||||
|
||||
|
||||
class BucketManager:
|
||||
@@ -171,6 +174,7 @@ class BucketManager:
|
||||
def select_bucket(self, image_width, image_height):
|
||||
aspect_ratio = image_width / image_height
|
||||
if not self.no_upscale:
|
||||
# 拡大および縮小を行う
|
||||
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
||||
reso = (image_width, image_height)
|
||||
if reso in self.predefined_resos_set:
|
||||
@@ -189,6 +193,7 @@ class BucketManager:
|
||||
resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
|
||||
# print("use predef", image_width, image_height, reso, resized_size)
|
||||
else:
|
||||
# 縮小のみを行う
|
||||
if image_width * image_height > self.max_area:
|
||||
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
||||
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
||||
@@ -238,41 +243,40 @@ class BucketBatchIndex(NamedTuple):
|
||||
class AugHelper:
|
||||
def __init__(self):
|
||||
# prepare all possible augmentators
|
||||
color_aug_method = albu.OneOf(
|
||||
self.color_aug_method = albu.OneOf(
|
||||
[
|
||||
albu.HueSaturationValue(8, 0, 0, p=0.5),
|
||||
albu.RandomGamma((95, 105), p=0.5),
|
||||
],
|
||||
p=0.33,
|
||||
)
|
||||
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
||||
|
||||
# key: (use_color_aug, use_flip_aug)
|
||||
self.augmentors = {
|
||||
(True, True): albu.Compose(
|
||||
[
|
||||
color_aug_method,
|
||||
flip_aug_method,
|
||||
],
|
||||
p=1.0,
|
||||
),
|
||||
(True, False): albu.Compose(
|
||||
[
|
||||
color_aug_method,
|
||||
],
|
||||
p=1.0,
|
||||
),
|
||||
(False, True): albu.Compose(
|
||||
[
|
||||
flip_aug_method,
|
||||
],
|
||||
p=1.0,
|
||||
),
|
||||
(False, False): None,
|
||||
}
|
||||
# self.augmentors = {
|
||||
# (True, True): albu.Compose(
|
||||
# [
|
||||
# color_aug_method,
|
||||
# flip_aug_method,
|
||||
# ],
|
||||
# p=1.0,
|
||||
# ),
|
||||
# (True, False): albu.Compose(
|
||||
# [
|
||||
# color_aug_method,
|
||||
# ],
|
||||
# p=1.0,
|
||||
# ),
|
||||
# (False, True): albu.Compose(
|
||||
# [
|
||||
# flip_aug_method,
|
||||
# ],
|
||||
# p=1.0,
|
||||
# ),
|
||||
# (False, False): None,
|
||||
# }
|
||||
|
||||
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
||||
return self.augmentors[(use_color_aug, use_flip_aug)]
|
||||
def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]:
|
||||
return self.color_aug_method if use_color_aug else None
|
||||
|
||||
|
||||
class BaseSubset:
|
||||
@@ -454,10 +458,16 @@ class ControlNetSubset(BaseSubset):
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
|
||||
self,
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
|
||||
max_token_length: int,
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
|
||||
|
||||
self.max_token_length = max_token_length
|
||||
# width/height is used when enable_bucket==False
|
||||
self.width, self.height = (None, None) if resolution is None else resolution
|
||||
@@ -478,7 +488,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_no_upscale = None
|
||||
self.bucket_info = None # for metadata
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
|
||||
@@ -594,48 +604,49 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
return caption
|
||||
|
||||
def get_input_ids(self, caption):
|
||||
input_ids = self.tokenizer(
|
||||
def get_input_ids(self, caption, tokenizer=None):
|
||||
if tokenizer is None:
|
||||
tokenizer = self.tokenizers[0]
|
||||
|
||||
input_ids = tokenizer(
|
||||
caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
||||
if self.tokenizer_max_length > tokenizer.model_max_length:
|
||||
input_ids = input_ids.squeeze(0)
|
||||
iids_list = []
|
||||
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
||||
if tokenizer.pad_token_id == tokenizer.eos_token_id:
|
||||
# v1
|
||||
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||
for i in range(
|
||||
1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2
|
||||
1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2
|
||||
): # (1, 152, 75)
|
||||
ids_chunk = (
|
||||
input_ids[0].unsqueeze(0),
|
||||
input_ids[i : i + self.tokenizer.model_max_length - 2],
|
||||
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0),
|
||||
)
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
iids_list.append(ids_chunk)
|
||||
else:
|
||||
# v2
|
||||
# v2 or SDXL
|
||||
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
||||
for i in range(
|
||||
1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2
|
||||
):
|
||||
for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
||||
ids_chunk = (
|
||||
input_ids[0].unsqueeze(0), # BOS
|
||||
input_ids[i : i + self.tokenizer.model_max_length - 2],
|
||||
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0),
|
||||
) # PAD or EOS
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
|
||||
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
||||
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
||||
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
||||
ids_chunk[-1] = self.tokenizer.eos_token_id
|
||||
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
|
||||
ids_chunk[-1] = tokenizer.eos_token_id
|
||||
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
||||
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
||||
ids_chunk[1] = self.tokenizer.eos_token_id
|
||||
if ids_chunk[1] == tokenizer.pad_token_id:
|
||||
ids_chunk[1] = tokenizer.eos_token_id
|
||||
|
||||
iids_list.append(ids_chunk)
|
||||
|
||||
@@ -755,46 +766,58 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size, cond_img = None):
|
||||
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
||||
image_height, image_width = image.shape[0:2]
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
# リサイズする
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
if exists(cond_img):
|
||||
cond_img = cv2.resize(cond_img, resized_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
image_height, image_width = image.shape[0:2]
|
||||
original_size = (image_width, image_height)
|
||||
|
||||
crop_left_top = (0, 0)
|
||||
if image_width > reso[0]:
|
||||
trim_size = image_width - reso[0]
|
||||
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
||||
# print("w", trim_size, p)
|
||||
image = image[:, p : p + reso[0]]
|
||||
if exists(cond_img):
|
||||
cond_img = cond_img[:, p : p + reso[0]]
|
||||
crop_left_top = (p, 0)
|
||||
if image_height > reso[1]:
|
||||
trim_size = image_height - reso[1]
|
||||
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
||||
# print("h", trim_size, p)
|
||||
image = image[p : p + reso[1]]
|
||||
if exists(cond_img):
|
||||
cond_img = cond_img[p : p + reso[1]]
|
||||
crop_left_top = (crop_left_top[0], p)
|
||||
|
||||
assert (
|
||||
image.shape[0] == reso[1] and image.shape[1] == reso[0]
|
||||
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
|
||||
if exists(cond_img):
|
||||
assert (
|
||||
cond_img.shape[0] == reso[1] and cond_img.shape[1] == reso[0]
|
||||
), f"internal error, illegal trimmed size: {cond_img.shape}, {reso}"
|
||||
return image, cond_img
|
||||
|
||||
return image
|
||||
return image, original_size, crop_left_top
|
||||
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
|
||||
for npath in [npz_path, flipped_npz_path]:
|
||||
if npath is None:
|
||||
continue
|
||||
if not os.path.exists(npath):
|
||||
return False
|
||||
|
||||
npz = np.load(npath)
|
||||
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
|
||||
return False
|
||||
|
||||
cached_latents = npz["latents"]
|
||||
|
||||
if cached_latents.shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
# ちょっと速くした
|
||||
print("caching latents.")
|
||||
@@ -811,38 +834,26 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None:
|
||||
info.latents = self.load_latents_from_npz(info, False)
|
||||
info.latents, info.latents_original_size, info.latents_crop_left_top = self.load_latents_from_npz(info, False)
|
||||
info.latents = torch.FloatTensor(info.latents)
|
||||
|
||||
# might be None, but that's ok because check is done in dataset
|
||||
info.latents_flipped = self.load_latents_from_npz(info, True)
|
||||
info.latents_flipped, _, _ = self.load_latents_from_npz(info, True) # might be None
|
||||
if info.latents_flipped is not None:
|
||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
# TODO: refactor to unify with FineTuningDataset
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
||||
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
|
||||
if not is_main_process:
|
||||
continue
|
||||
|
||||
cache_available = False
|
||||
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
if os.path.exists(info.latents_npz):
|
||||
cached_latents = np.load(info.latents_npz)["arr_0"]
|
||||
if cached_latents.shape[1:3] == expected_latents_size:
|
||||
cache_available = True
|
||||
cache_available = self.is_disk_cached_latents_is_expected(
|
||||
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None
|
||||
)
|
||||
|
||||
if subset.flip_aug:
|
||||
cache_available = False
|
||||
if os.path.exists(info.latents_npz_flipped):
|
||||
cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"]
|
||||
if cached_latents_flipped.shape[1:3] == expected_latents_size:
|
||||
cache_available = True
|
||||
|
||||
if cache_available:
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
@@ -868,10 +879,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
images = []
|
||||
for info in batch:
|
||||
image = self.load_image(info.absolute_path)
|
||||
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
||||
image, original_size, crop_left_top = self.trim_and_resize_if_required(
|
||||
subset, image, info.bucket_reso, info.resized_size
|
||||
)
|
||||
image = self.image_transforms(image)
|
||||
images.append(image)
|
||||
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_left_top = crop_left_top
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
@@ -879,7 +895,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
for info, latent in zip(batch, latents):
|
||||
if cache_to_disk:
|
||||
np.savez(info.latents_npz, latent.float().numpy())
|
||||
np.savez(
|
||||
info.latents_npz,
|
||||
latents=latent.float().numpy(),
|
||||
original_size=np.array(info.latents_original_size),
|
||||
crop_left_top=np.array(info.latents_crop_left_top),
|
||||
)
|
||||
else:
|
||||
info.latents = latent
|
||||
|
||||
@@ -888,7 +909,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
for info, latent in zip(batch, latents):
|
||||
if cache_to_disk:
|
||||
np.savez(info.latents_npz_flipped, latent.float().numpy())
|
||||
np.savez(
|
||||
info.latents_npz_flipped,
|
||||
latents=latent.float().numpy(),
|
||||
original_size=np.array(info.latents_original_size),
|
||||
crop_left_top=np.array(info.latents_crop_left_top), # reverse horizontally when use flipped latents
|
||||
)
|
||||
else:
|
||||
info.latents_flipped = latent
|
||||
|
||||
@@ -961,8 +987,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
||||
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
||||
if npz_file is None:
|
||||
return None
|
||||
return np.load(npz_file)["arr_0"]
|
||||
return None, None, None
|
||||
|
||||
npz = np.load(npz_file)
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
crop_left_top = npz["crop_left_top"].tolist()
|
||||
return latents, original_size, crop_left_top
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
@@ -975,21 +1006,35 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
loss_weights = []
|
||||
captions = []
|
||||
input_ids_list = []
|
||||
input_ids2_list = []
|
||||
latents_list = []
|
||||
images = []
|
||||
original_sizes_hw = []
|
||||
crop_top_lefts = []
|
||||
target_sizes_hw = []
|
||||
flippeds = [] # 変数名が微妙
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
|
||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
||||
original_size = image_info.latents_original_size
|
||||
crop_left_top = image_info.latents_crop_left_top # calc values later if flipped
|
||||
if not flipped:
|
||||
latents = image_info.latents
|
||||
else:
|
||||
latents = image_info.latents_flipped
|
||||
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
||||
latents, original_size, crop_left_top = self.load_latents_from_npz(image_info, flipped)
|
||||
latents = torch.FloatTensor(latents)
|
||||
|
||||
image = None
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
@@ -997,7 +1042,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
||||
img, original_size, crop_left_top = self.trim_and_resize_if_required(
|
||||
subset, img, image_info.bucket_reso, image_info.resized_size
|
||||
)
|
||||
else:
|
||||
if face_cx > 0: # 顔位置情報あり
|
||||
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
||||
@@ -1017,17 +1064,33 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
im_h == self.height and im_w == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
original_size = [im_w, im_h]
|
||||
crop_left_top = [0, 0]
|
||||
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug)
|
||||
if aug is not None:
|
||||
img = aug(image=img)["image"]
|
||||
|
||||
if flipped:
|
||||
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
|
||||
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
if flipped:
|
||||
crop_left_top = (original_size[0] - crop_left_top[0] - target_size[0], crop_left_top[1])
|
||||
|
||||
original_sizes_hw.append((original_size[1], original_size[0]))
|
||||
crop_top_lefts.append((crop_left_top[1], crop_left_top[0]))
|
||||
target_sizes_hw.append((target_size[1], target_size[0]))
|
||||
flippeds.append(flipped)
|
||||
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
@@ -1039,22 +1102,33 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
captions.append(caption)
|
||||
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer)
|
||||
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption)
|
||||
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
if len(self.tokenizers) > 1:
|
||||
if self.XTI_layers:
|
||||
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
else:
|
||||
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
input_ids2_list.append(token_caption2)
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
|
||||
if self.token_padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
if len(self.tokenizers) > 1:
|
||||
# following may not work in SDXL, keep the line for future update
|
||||
example["input_ids2"] = self.tokenizer[1](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
else:
|
||||
# batch processing seems to be good
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
|
||||
|
||||
if images[0] is not None:
|
||||
images = torch.stack(images)
|
||||
@@ -1066,6 +1140,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||
example["captions"] = captions
|
||||
|
||||
example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw])
|
||||
example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts])
|
||||
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
|
||||
example["flippeds"] = flippeds
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
return example
|
||||
@@ -1462,151 +1541,86 @@ class ControlNetDataset(BaseDataset):
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset) -> None:
|
||||
debug_dataset,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
self.conditioning_image_data: Dict[str, ImageInfo] = {}
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
db_subset = DreamBoothSubset(
|
||||
subset.image_dir,
|
||||
False,
|
||||
None,
|
||||
subset.caption_extension,
|
||||
subset.num_repeats,
|
||||
subset.shuffle_caption,
|
||||
subset.keep_tokens,
|
||||
subset.color_aug,
|
||||
subset.flip_aug,
|
||||
subset.face_crop_aug_range,
|
||||
subset.random_crop,
|
||||
subset.caption_dropout_rate,
|
||||
subset.caption_dropout_every_n_epochs,
|
||||
subset.caption_tag_dropout_rate,
|
||||
subset.token_warmup_min,
|
||||
subset.token_warmup_step,
|
||||
)
|
||||
db_subsets.append(db_subset)
|
||||
|
||||
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
||||
db_subsets,
|
||||
batch_size,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
enable_bucket,
|
||||
min_bucket_reso,
|
||||
max_bucket_reso,
|
||||
bucket_reso_steps,
|
||||
bucket_no_upscale,
|
||||
1.0,
|
||||
debug_dataset,
|
||||
)
|
||||
|
||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||
self.image_data = self.dreambooth_dataset_delegate.image_data
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.latents_cache = None
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
|
||||
self.num_reg_images = 0
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
assert (
|
||||
min(resolution) >= min_bucket_reso
|
||||
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert (
|
||||
max(resolution) <= max_bucket_reso
|
||||
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path, caption_extension):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
tokens = base_name.split("_")
|
||||
if len(tokens) >= 5:
|
||||
base_name_face_det = "_".join(tokens[:-4])
|
||||
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
||||
|
||||
caption = None
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding="utf-8") as f:
|
||||
try:
|
||||
lines = f.readlines()
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
caption = lines[0].strip()
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
cond_imgs_with_img = set()
|
||||
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
|
||||
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
|
||||
subset = None
|
||||
for s in subsets:
|
||||
if s.image_dir == db_subset.image_dir:
|
||||
subset = s
|
||||
break
|
||||
return caption
|
||||
assert subset is not None, "internal error: subset not found"
|
||||
|
||||
def load_controlnet_dir(subset: ControlNetSubset):
|
||||
if not os.path.isdir(subset.image_dir):
|
||||
print(f"not directory: {subset.image_dir}")
|
||||
return [], []
|
||||
if not os.path.isdir(subset.conditioning_data_dir):
|
||||
print(f"not directory: {subset.conditioning_data_dir}")
|
||||
return [], []
|
||||
continue
|
||||
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||
img_paths = sorted(img_paths)
|
||||
conditioning_img_paths = sorted(conditioning_img_paths)
|
||||
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
|
||||
img_basename = os.path.basename(info.absolute_path)
|
||||
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
|
||||
if not os.path.exists(ctrl_img_path):
|
||||
missing_imgs.append(img_basename)
|
||||
|
||||
img_basenames = [os.path.basename(img) for img in img_paths]
|
||||
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
|
||||
missing_imgs = []
|
||||
extra_imgs = []
|
||||
info.cond_img_path = ctrl_img_path
|
||||
cond_imgs_with_img.add(ctrl_img_path)
|
||||
|
||||
for img in img_basenames:
|
||||
if img not in conditioning_img_basenames:
|
||||
missing_imgs.append(img)
|
||||
for img in conditioning_img_basenames:
|
||||
if img not in img_basenames:
|
||||
extra_imgs.append(img)
|
||||
|
||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if cap_for_img is None:
|
||||
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
|
||||
captions.append("")
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
if missing_captions:
|
||||
number_of_missing_captions = len(missing_captions)
|
||||
number_of_missing_captions_to_show = 5
|
||||
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
|
||||
|
||||
print(
|
||||
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
|
||||
)
|
||||
for i, missing_caption in enumerate(missing_captions):
|
||||
if i >= number_of_missing_captions_to_show:
|
||||
print(missing_caption + f"... and {remaining_missing_captions} more")
|
||||
break
|
||||
print(missing_caption)
|
||||
return img_paths, conditioning_img_paths, captions
|
||||
|
||||
print("prepare images.")
|
||||
num_train_images = 0
|
||||
extra_imgs = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
print(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
|
||||
)
|
||||
continue
|
||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||
extra_imgs.extend(
|
||||
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
|
||||
)
|
||||
|
||||
if subset in self.subsets:
|
||||
print(
|
||||
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
|
||||
)
|
||||
continue
|
||||
|
||||
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
|
||||
if len(img_paths) < 1:
|
||||
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
||||
continue
|
||||
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
|
||||
setattr(info, "cond_img_path", cond_img_path)
|
||||
self.register_image(info, subset)
|
||||
|
||||
subset.img_count = len(img_paths)
|
||||
self.subsets.append(subset)
|
||||
|
||||
print(f"{num_train_images} train images with repeating.")
|
||||
self.num_train_images = num_train_images
|
||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
|
||||
self.conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
@@ -1614,88 +1628,58 @@ class ControlNetDataset(BaseDataset):
|
||||
]
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
||||
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
||||
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
||||
def make_buckets(self):
|
||||
self.dreambooth_dataset_delegate.make_buckets()
|
||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
|
||||
|
||||
def __len__(self):
|
||||
return self.dreambooth_dataset_delegate.__len__()
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = self.dreambooth_dataset_delegate[index]
|
||||
|
||||
bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[
|
||||
self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index
|
||||
]
|
||||
bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size
|
||||
image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size
|
||||
|
||||
loss_weights = []
|
||||
captions = []
|
||||
input_ids_list = []
|
||||
latents_list = []
|
||||
images = []
|
||||
conditioning_images = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(1.0)
|
||||
for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]):
|
||||
image_info = self.dreambooth_dataset_delegate.image_data[image_key]
|
||||
|
||||
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
|
||||
target_size_hw = example["target_sizes_hw"][i]
|
||||
original_size_hw = example["original_sizes_hw"][i]
|
||||
crop_top_left = example["crop_top_lefts"][i]
|
||||
flipped = example["flippeds"][i]
|
||||
cond_img = self.load_image(image_info.cond_img_path)
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
||||
latents = torch.FloatTensor(latents)
|
||||
image = None
|
||||
if self.dreambooth_dataset_delegate.enable_bucket:
|
||||
cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
assert (
|
||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||
ct, cl = crop_top_left
|
||||
h, w = target_size_hw
|
||||
cond_img = cond_img[ct : ct + h, cl : cl + w]
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img = self.load_image(image_info.absolute_path)
|
||||
cond_img = self.load_image(image_info.cond_img_path)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
cond_img.shape[0] == self.height and cond_img.shape[1] == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
if self.enable_bucket:
|
||||
img, cond_img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size, cond_img=cond_img)
|
||||
else:
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
im_h == self.height and im_w == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
||||
if aug is not None:
|
||||
img = aug(image=img)["image"]
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
if flipped:
|
||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||
|
||||
cond_img = self.conditioning_image_transforms(cond_img)
|
||||
conditioning_images.append(cond_img)
|
||||
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
captions.append(caption)
|
||||
token_caption = self.get_input_ids(caption)
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
|
||||
if images[0] is not None:
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
else:
|
||||
images = None
|
||||
example["images"] = images
|
||||
|
||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||
example["captions"] = captions
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
|
||||
example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
return example
|
||||
|
||||
|
||||
# behave as Dataset mock
|
||||
class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
||||
@@ -1773,18 +1757,42 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
example = train_dataset[idx]
|
||||
if example["latents"] is not None:
|
||||
print(f"sample has latents from npz file: {example['latents'].size()}")
|
||||
for j, (ik, cap, lw, iid) in enumerate(
|
||||
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
|
||||
for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
|
||||
zip(
|
||||
example["image_keys"],
|
||||
example["captions"],
|
||||
example["loss_weights"],
|
||||
example["input_ids"],
|
||||
example["original_sizes_hw"],
|
||||
example["crop_top_lefts"],
|
||||
example["target_sizes_hw"],
|
||||
example["flippeds"],
|
||||
)
|
||||
):
|
||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
||||
print(
|
||||
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop left top: {crptl}, target size: {trgsz}, flipped: {flpdz}'
|
||||
)
|
||||
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
if "input_ids2" in example:
|
||||
print(f"input ids2: {example['input_ids2'][j]}")
|
||||
if example["images"] is not None:
|
||||
im = example["images"][j]
|
||||
print(f"image size: {im.size()}")
|
||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||
|
||||
if "conditioning_images" in example:
|
||||
cond_img = example["conditioning_images"][j]
|
||||
print(f"conditioning image size: {cond_img.size()}")
|
||||
cond_img = (cond_img.numpy() * 255.0).astype(np.uint8)
|
||||
cond_img = np.transpose(cond_img, (1, 2, 0))
|
||||
cond_img = cond_img[:, :, ::-1]
|
||||
if os.name == "nt":
|
||||
cv2.imshow("cond_img", cond_img)
|
||||
|
||||
if os.name == "nt": # only windows
|
||||
cv2.imshow("img", im)
|
||||
k = cv2.waitKey()
|
||||
@@ -2011,7 +2019,6 @@ def get_git_revision_hash() -> str:
|
||||
return "(unknown)"
|
||||
|
||||
|
||||
|
||||
# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
# replace_attentions_for_hypernetwork()
|
||||
# # unet is not used currently, but it is here for future use
|
||||
@@ -2063,8 +2070,9 @@ def get_git_revision_hash() -> str:
|
||||
# out = self.to_out[1](out)
|
||||
# return out
|
||||
|
||||
|
||||
# diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
if mem_eff_attn:
|
||||
print("Enable memory efficient attention for U-Net")
|
||||
unet.set_use_memory_efficient_attention(False, True)
|
||||
@@ -2080,6 +2088,7 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa
|
||||
print("Enable SDPA for U-Net")
|
||||
unet.set_use_sdpa(True)
|
||||
|
||||
|
||||
"""
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
# vae is not used currently, but it is here for future use
|
||||
@@ -2327,7 +2336,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)")
|
||||
parser.add_argument(
|
||||
"--sdpa",
|
||||
action="store_true",
|
||||
help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
|
||||
)
|
||||
@@ -3231,7 +3244,9 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
if load_stable_diffusion_format:
|
||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2)
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
|
||||
args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
|
||||
)
|
||||
else:
|
||||
# Diffusers model is loaded to CPU
|
||||
print(f"load Diffusers pretrained models: {name_or_path}")
|
||||
@@ -3281,7 +3296,10 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
||||
args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
|
||||
args,
|
||||
weight_dtype,
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
@@ -3595,7 +3613,17 @@ SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
|
||||
def sample_images(
|
||||
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None
|
||||
accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
prompt_replacement=None,
|
||||
controlnet=None,
|
||||
):
|
||||
"""
|
||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||
@@ -3690,7 +3718,7 @@ def sample_images(
|
||||
requires_safety_checker=False,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
|
||||
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
|
||||
pipeline.to(device)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
@@ -3765,7 +3793,6 @@ def sample_images(
|
||||
controlnet_image = m.group(1)
|
||||
continue
|
||||
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
Reference in New Issue
Block a user