Fix image size batch for SDXL

This commit is contained in:
rockerBOO
2024-06-19 17:19:57 -04:00
parent 9637ce0c23
commit 2c646cc62b

View File

@@ -1701,24 +1701,38 @@ class BaseDataset(torch.utils.data.Dataset):
images.append(image)
latents_list.append(None)
alpha_mask_list.append(alpha_mask)
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)
else:
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped)
images.append(image)
latents_list.append(None)
alpha_mask_list.append(alpha_mask)
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)
# captionとtext encoder outputを処理する
caption = image_info.caption # default