mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix bucketing
This commit is contained in:
@@ -754,12 +754,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
||||
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size, cond_img = None):
|
||||
image_height, image_width = image.shape[0:2]
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
# リサイズする
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
if exists(cond_img):
|
||||
cond_img = cv2.resize(cond_img, resized_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
image_height, image_width = image.shape[0:2]
|
||||
if image_width > reso[0]:
|
||||
@@ -767,15 +769,26 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
||||
# print("w", trim_size, p)
|
||||
image = image[:, p : p + reso[0]]
|
||||
if exists(cond_img):
|
||||
cond_img = cond_img[:, p : p + reso[0]]
|
||||
if image_height > reso[1]:
|
||||
trim_size = image_height - reso[1]
|
||||
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
||||
# print("h", trim_size, p)
|
||||
image = image[p : p + reso[1]]
|
||||
if exists(cond_img):
|
||||
cond_img = cond_img[p : p + reso[1]]
|
||||
|
||||
assert (
|
||||
image.shape[0] == reso[1] and image.shape[1] == reso[0]
|
||||
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
|
||||
if exists(cond_img):
|
||||
assert (
|
||||
cond_img.shape[0] == reso[1] and cond_img.shape[1] == reso[0]
|
||||
), f"internal error, illegal trimmed size: {cond_img.shape}, {reso}"
|
||||
return image, cond_img
|
||||
|
||||
return image
|
||||
|
||||
def is_latent_cacheable(self):
|
||||
@@ -1617,6 +1630,8 @@ class ControlNetDataset(BaseDataset):
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(1.0)
|
||||
|
||||
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
||||
@@ -1628,10 +1643,11 @@ class ControlNetDataset(BaseDataset):
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img = self.load_image(image_info.absolute_path)
|
||||
cond_img = self.load_image(image_info.cond_img_path)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
||||
img, cond_img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size, cond_img=cond_img)
|
||||
else:
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
@@ -1649,41 +1665,18 @@ class ControlNetDataset(BaseDataset):
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
captions.append(caption)
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer)
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption)
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
|
||||
|
||||
cond_img = self.load_image(image_info.cond_img_path)
|
||||
if self.enable_bucket:
|
||||
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
|
||||
cond_img = self.conditioning_image_transforms(cond_img)
|
||||
conditioning_images.append(cond_img)
|
||||
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
captions.append(caption)
|
||||
token_caption = self.get_input_ids(caption)
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
|
||||
if self.token_padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
else:
|
||||
# batch processing seems to be good
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
|
||||
if images[0] is not None:
|
||||
images = torch.stack(images)
|
||||
|
||||
Reference in New Issue
Block a user