From 9a415ba9651be2d4be7b3321ad244f9001b83738 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 27 Feb 2025 00:21:57 +0300 Subject: [PATCH 1/6] JPEG XL support --- library/jpeg_xl_util.py | 184 ++++++++++++++++++++++++++++++++++++++++ library/train_util.py | 6 +- 2 files changed, 189 insertions(+), 1 deletion(-) create mode 100644 library/jpeg_xl_util.py diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py new file mode 100644 index 00000000..82da10ef --- /dev/null +++ b/library/jpeg_xl_util.py @@ -0,0 +1,184 @@ +# Modifed from https://github.com/Fraetor/jxl_decode +# Added partial read support for 200x speedup +import os + +class JXLBitstream: + """ + A stream of bits with methods for easy handling. + """ + + def __init__(self, file, offset=0, offsets=[]) -> None: + self.shift = 0 + self.bitstream = [] + self.file = file + self.offset = offset + self.offsets = offsets + if self.offsets: + self.offset = self.offsets[0][1] + self.previous_data_len = 0 + self.index = 0 + self.file.seek(self.offset) + + def get_bits(self, length: int = 1) -> int: + if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]: + self.partial_to_read_length = length + if self.shift < self.previous_data_len + self.offsets[self.index][2]: + self.partial_read(0, length) + self.bitstream += self.file.read(self.partial_to_read_length) + else: + self.bitstream += self.file.read(length) + bitmask = 2**length - 1 + bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask + self.shift += length + return bits + + def partial_read(self, readed_length, length): + self.previous_data_len += self.offsets[self.index][2] + to_read_length = self.previous_data_len - (self.shift + readed_length) + self.bitstream += self.file.read(to_read_length) + readed_length += to_read_length + self.partial_to_read_length -= to_read_length + self.index += 1 + self.file.seek(self.offsets[self.index][1]) + if self.shift + length > self.previous_data_len + self.offsets[self.index][2]: + self.partial_read(readed_length, length) + + +def decode_codestream(file, offset=0, offsets=[]): + """ + Decodes the actual codestream. + JXL codestream specification: http://www-internal/2022/18181-1 + """ + + # Convert codestream to int within an object to get some handy methods. + codestream = JXLBitstream(file, offset=offset, offsets=offsets) + + # Skip signature + codestream.get_bits(16) + + # SizeHeader + div8 = codestream.get_bits(1) + if div8: + height = 8 * (1 + codestream.get_bits(5)) + else: + distribution = codestream.get_bits(2) + match distribution: + case 0: + height = 1 + codestream.get_bits(9) + case 1: + height = 1 + codestream.get_bits(13) + case 2: + height = 1 + codestream.get_bits(18) + case 3: + height = 1 + codestream.get_bits(30) + ratio = codestream.get_bits(3) + if div8 and not ratio: + width = 8 * (1 + codestream.get_bits(5)) + elif not ratio: + distribution = codestream.get_bits(2) + match distribution: + case 0: + width = 1 + codestream.get_bits(9) + case 1: + width = 1 + codestream.get_bits(13) + case 2: + width = 1 + codestream.get_bits(18) + case 3: + width = 1 + codestream.get_bits(30) + else: + match ratio: + case 1: + width = height + case 2: + width = (height * 12) // 10 + case 3: + width = (height * 4) // 3 + case 4: + width = (height * 3) // 2 + case 5: + width = (height * 16) // 9 + case 6: + width = (height * 5) // 4 + case 7: + width = (height * 2) // 1 + return width, height + + +def decode_container(file): + """ + Parses the ISOBMFF container, extracts the codestream, and decodes it. + JXL container specification: http://www-internal/2022/18181-2 + """ + + def parse_box(file, file_start) -> dict: + file.seek(file_start) + LBox = int.from_bytes(file.read(4), "big") + XLBox = None + if 1 < LBox <= 8: + raise ValueError(f"Invalid LBox at byte {file_start}.") + if LBox == 1: + file.seek(file_start + 8) + XLBox = int.from_bytes(file.read(8), "big") + if XLBox <= 16: + raise ValueError(f"Invalid XLBox at byte {file_start}.") + if XLBox: + header_length = 16 + box_length = XLBox + else: + header_length = 8 + if LBox == 0: + box_length = os.fstat(file.fileno()).st_size - file_start + else: + box_length = LBox + file.seek(file_start + 4) + box_type = file.read(4) + file.seek(file_start) + return { + "length": box_length, + "type": box_type, + "offset": header_length, + } + + file.seek(0) + # Reject files missing required boxes. These two boxes are required to be at + # the start and contain no values, so we can manually check there presence. + # Signature box. (Redundant as has already been checked.) + if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"): + raise ValueError("Invalid signature box.") + # File Type box. + if file.read(20) != bytes.fromhex( + "00000014 66747970 6A786C20 00000000 6A786C20" + ): + raise ValueError("Invalid file type box.") + + offset = 0 + offsets = [] + data_offset_not_found = True + container_pointer = 32 + file_size = os.fstat(file.fileno()).st_size + while data_offset_not_found: + box = parse_box(file, container_pointer) + match box["type"]: + case b"jxlc": + offset = container_pointer + box["offset"] + data_offset_not_found = False + case b"jxlp": + file.seek(container_pointer + box["offset"]) + index = int.from_bytes(file.read(4), "big") + offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4]) + container_pointer += box["length"] + if container_pointer >= file_size: + data_offset_not_found = False + + if offsets: + offsets.sort(key=lambda i: i[0]) + file.seek(0) + + return decode_codestream(file, offset=offset, offsets=offsets) + + +def get_jxl_size(path): + with open(path, "rb") as file: + if file.read(2) == bytes.fromhex("FF0A"): + return decode_codestream(file) + return decode_container(file) diff --git a/library/train_util.py b/library/train_util.py index 100ef475..916b8834 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -118,14 +118,16 @@ except: # JPEG-XL on Linux try: from jxlpy import JXLImagePlugin + from library.jpeg_xl_util import get_jxl_size IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) except: pass -# JPEG-XL on Windows +# JPEG-XL on Linux and Windows try: import pillow_jxl + from library.jpeg_xl_util import get_jxl_size IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) except: @@ -1156,6 +1158,8 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): + if image_path.endswith(".jxl"): + return get_jxl_size(image_path) return imagesize.get(image_path) def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): From 2f69f4dbdb679ca887d3bc4438667360ae934ca9 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 27 Feb 2025 00:30:19 +0300 Subject: [PATCH 2/6] fix typo --- library/jpeg_xl_util.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py index 82da10ef..d127d44d 100644 --- a/library/jpeg_xl_util.py +++ b/library/jpeg_xl_util.py @@ -1,5 +1,5 @@ -# Modifed from https://github.com/Fraetor/jxl_decode -# Added partial read support for 200x speedup +# Modified from https://github.com/Fraetor/jxl_decode +# Added partial read support for up to 200x speedup import os class JXLBitstream: @@ -32,16 +32,16 @@ class JXLBitstream: self.shift += length return bits - def partial_read(self, readed_length, length): + def partial_read(self, current_length, length): self.previous_data_len += self.offsets[self.index][2] - to_read_length = self.previous_data_len - (self.shift + readed_length) + to_read_length = self.previous_data_len - (self.shift + current_length) self.bitstream += self.file.read(to_read_length) - readed_length += to_read_length + current_length += to_read_length self.partial_to_read_length -= to_read_length self.index += 1 self.file.seek(self.offsets[self.index][1]) if self.shift + length > self.previous_data_len + self.offsets[self.index][2]: - self.partial_read(readed_length, length) + self.partial_read(current_length, length) def decode_codestream(file, offset=0, offsets=[]): From 7e90cdd47a6019739d59beb161741891b79eeaef Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 17 Mar 2025 17:26:08 +0300 Subject: [PATCH 3/6] use bytearray and add typing hints --- library/jpeg_xl_util.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py index d127d44d..3f5e7f72 100644 --- a/library/jpeg_xl_util.py +++ b/library/jpeg_xl_util.py @@ -1,15 +1,17 @@ # Modified from https://github.com/Fraetor/jxl_decode # Added partial read support for up to 200x speedup + import os +from typing import List, Tuple class JXLBitstream: """ A stream of bits with methods for easy handling. """ - def __init__(self, file, offset=0, offsets=[]) -> None: + def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None): self.shift = 0 - self.bitstream = [] + self.bitstream = bytearray() self.file = file self.offset = offset self.offsets = offsets @@ -32,7 +34,7 @@ class JXLBitstream: self.shift += length return bits - def partial_read(self, current_length, length): + def partial_read(self, current_length: int, length: int) -> None: self.previous_data_len += self.offsets[self.index][2] to_read_length = self.previous_data_len - (self.shift + current_length) self.bitstream += self.file.read(to_read_length) @@ -44,7 +46,7 @@ class JXLBitstream: self.partial_read(current_length, length) -def decode_codestream(file, offset=0, offsets=[]): +def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]: """ Decodes the actual codestream. JXL codestream specification: http://www-internal/2022/18181-1 @@ -104,13 +106,13 @@ def decode_codestream(file, offset=0, offsets=[]): return width, height -def decode_container(file): +def decode_container(file) -> Tuple[int,int]: """ Parses the ISOBMFF container, extracts the codestream, and decodes it. JXL container specification: http://www-internal/2022/18181-2 """ - def parse_box(file, file_start) -> dict: + def parse_box(file, file_start: int) -> dict: file.seek(file_start) LBox = int.from_bytes(file.read(4), "big") XLBox = None @@ -177,7 +179,7 @@ def decode_container(file): return decode_codestream(file, offset=offset, offsets=offsets) -def get_jxl_size(path): +def get_jxl_size(path: str) -> Tuple[int,int]: with open(path, "rb") as file: if file.read(2) == bytes.fromhex("FF0A"): return decode_codestream(file) From 564ec5fb7f6027d89b565cff1e8ed81a9e89ae07 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 17 Mar 2025 17:41:03 +0300 Subject: [PATCH 4/6] use extend instead of += --- library/jpeg_xl_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py index 3f5e7f72..ade24a05 100644 --- a/library/jpeg_xl_util.py +++ b/library/jpeg_xl_util.py @@ -26,9 +26,9 @@ class JXLBitstream: self.partial_to_read_length = length if self.shift < self.previous_data_len + self.offsets[self.index][2]: self.partial_read(0, length) - self.bitstream += self.file.read(self.partial_to_read_length) + self.bitstream.extend(self.file.read(self.partial_to_read_length)) else: - self.bitstream += self.file.read(length) + self.bitstream.extend(self.file.read(length)) bitmask = 2**length - 1 bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask self.shift += length @@ -37,7 +37,7 @@ class JXLBitstream: def partial_read(self, current_length: int, length: int) -> None: self.previous_data_len += self.offsets[self.index][2] to_read_length = self.previous_data_len - (self.shift + current_length) - self.bitstream += self.file.read(to_read_length) + self.bitstream.extend(self.file.read(to_read_length)) current_length += to_read_length self.partial_to_read_length -= to_read_length self.index += 1 From 620a06f517032fff9842b81950795bb14c0ad361 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 17 Mar 2025 17:44:29 +0300 Subject: [PATCH 5/6] Check for uppercase file extension too --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 916b8834..3b6f7663 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1158,7 +1158,7 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): - if image_path.endswith(".jxl"): + if image_path.endswith(".jxl") or image_path.endswith(".JXL"): return get_jxl_size(image_path) return imagesize.get(image_path) From 583ab27b3ccf9e360be07d424e8c15f90d1041ad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 31 Mar 2025 22:02:25 +0900 Subject: [PATCH 6/6] doc: update license information in jpeg_xl_util.py --- library/jpeg_xl_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py index ade24a05..c2e3a393 100644 --- a/library/jpeg_xl_util.py +++ b/library/jpeg_xl_util.py @@ -1,4 +1,4 @@ -# Modified from https://github.com/Fraetor/jxl_decode +# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT # Added partial read support for up to 200x speedup import os