diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py new file mode 100644 index 00000000..c2e3a393 --- /dev/null +++ b/library/jpeg_xl_util.py @@ -0,0 +1,186 @@ +# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT +# Added partial read support for up to 200x speedup + +import os +from typing import List, Tuple + +class JXLBitstream: + """ + A stream of bits with methods for easy handling. + """ + + def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None): + self.shift = 0 + self.bitstream = bytearray() + self.file = file + self.offset = offset + self.offsets = offsets + if self.offsets: + self.offset = self.offsets[0][1] + self.previous_data_len = 0 + self.index = 0 + self.file.seek(self.offset) + + def get_bits(self, length: int = 1) -> int: + if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]: + self.partial_to_read_length = length + if self.shift < self.previous_data_len + self.offsets[self.index][2]: + self.partial_read(0, length) + self.bitstream.extend(self.file.read(self.partial_to_read_length)) + else: + self.bitstream.extend(self.file.read(length)) + bitmask = 2**length - 1 + bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask + self.shift += length + return bits + + def partial_read(self, current_length: int, length: int) -> None: + self.previous_data_len += self.offsets[self.index][2] + to_read_length = self.previous_data_len - (self.shift + current_length) + self.bitstream.extend(self.file.read(to_read_length)) + current_length += to_read_length + self.partial_to_read_length -= to_read_length + self.index += 1 + self.file.seek(self.offsets[self.index][1]) + if self.shift + length > self.previous_data_len + self.offsets[self.index][2]: + self.partial_read(current_length, length) + + +def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]: + """ + Decodes the actual codestream. + JXL codestream specification: http://www-internal/2022/18181-1 + """ + + # Convert codestream to int within an object to get some handy methods. + codestream = JXLBitstream(file, offset=offset, offsets=offsets) + + # Skip signature + codestream.get_bits(16) + + # SizeHeader + div8 = codestream.get_bits(1) + if div8: + height = 8 * (1 + codestream.get_bits(5)) + else: + distribution = codestream.get_bits(2) + match distribution: + case 0: + height = 1 + codestream.get_bits(9) + case 1: + height = 1 + codestream.get_bits(13) + case 2: + height = 1 + codestream.get_bits(18) + case 3: + height = 1 + codestream.get_bits(30) + ratio = codestream.get_bits(3) + if div8 and not ratio: + width = 8 * (1 + codestream.get_bits(5)) + elif not ratio: + distribution = codestream.get_bits(2) + match distribution: + case 0: + width = 1 + codestream.get_bits(9) + case 1: + width = 1 + codestream.get_bits(13) + case 2: + width = 1 + codestream.get_bits(18) + case 3: + width = 1 + codestream.get_bits(30) + else: + match ratio: + case 1: + width = height + case 2: + width = (height * 12) // 10 + case 3: + width = (height * 4) // 3 + case 4: + width = (height * 3) // 2 + case 5: + width = (height * 16) // 9 + case 6: + width = (height * 5) // 4 + case 7: + width = (height * 2) // 1 + return width, height + + +def decode_container(file) -> Tuple[int,int]: + """ + Parses the ISOBMFF container, extracts the codestream, and decodes it. + JXL container specification: http://www-internal/2022/18181-2 + """ + + def parse_box(file, file_start: int) -> dict: + file.seek(file_start) + LBox = int.from_bytes(file.read(4), "big") + XLBox = None + if 1 < LBox <= 8: + raise ValueError(f"Invalid LBox at byte {file_start}.") + if LBox == 1: + file.seek(file_start + 8) + XLBox = int.from_bytes(file.read(8), "big") + if XLBox <= 16: + raise ValueError(f"Invalid XLBox at byte {file_start}.") + if XLBox: + header_length = 16 + box_length = XLBox + else: + header_length = 8 + if LBox == 0: + box_length = os.fstat(file.fileno()).st_size - file_start + else: + box_length = LBox + file.seek(file_start + 4) + box_type = file.read(4) + file.seek(file_start) + return { + "length": box_length, + "type": box_type, + "offset": header_length, + } + + file.seek(0) + # Reject files missing required boxes. These two boxes are required to be at + # the start and contain no values, so we can manually check there presence. + # Signature box. (Redundant as has already been checked.) + if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"): + raise ValueError("Invalid signature box.") + # File Type box. + if file.read(20) != bytes.fromhex( + "00000014 66747970 6A786C20 00000000 6A786C20" + ): + raise ValueError("Invalid file type box.") + + offset = 0 + offsets = [] + data_offset_not_found = True + container_pointer = 32 + file_size = os.fstat(file.fileno()).st_size + while data_offset_not_found: + box = parse_box(file, container_pointer) + match box["type"]: + case b"jxlc": + offset = container_pointer + box["offset"] + data_offset_not_found = False + case b"jxlp": + file.seek(container_pointer + box["offset"]) + index = int.from_bytes(file.read(4), "big") + offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4]) + container_pointer += box["length"] + if container_pointer >= file_size: + data_offset_not_found = False + + if offsets: + offsets.sort(key=lambda i: i[0]) + file.seek(0) + + return decode_codestream(file, offset=offset, offsets=offsets) + + +def get_jxl_size(path: str) -> Tuple[int,int]: + with open(path, "rb") as file: + if file.read(2) == bytes.fromhex("FF0A"): + return decode_codestream(file) + return decode_container(file) diff --git a/library/train_util.py b/library/train_util.py index 1ed1d3c2..53799081 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -113,14 +113,16 @@ except: # JPEG-XL on Linux try: from jxlpy import JXLImagePlugin + from library.jpeg_xl_util import get_jxl_size IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) except: pass -# JPEG-XL on Windows +# JPEG-XL on Linux and Windows try: import pillow_jxl + from library.jpeg_xl_util import get_jxl_size IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) except: @@ -1463,6 +1465,8 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): + if image_path.endswith(".jxl") or image_path.endswith(".JXL"): + return get_jxl_size(image_path) # return imagesize.get(image_path) image_size = imagesize.get(image_path) if image_size[0] <= 0: