mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
use bytearray and add typing hints
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
# Modified from https://github.com/Fraetor/jxl_decode
|
||||
# Added partial read support for up to 200x speedup
|
||||
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
class JXLBitstream:
|
||||
"""
|
||||
A stream of bits with methods for easy handling.
|
||||
"""
|
||||
|
||||
def __init__(self, file, offset=0, offsets=[]) -> None:
|
||||
def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
|
||||
self.shift = 0
|
||||
self.bitstream = []
|
||||
self.bitstream = bytearray()
|
||||
self.file = file
|
||||
self.offset = offset
|
||||
self.offsets = offsets
|
||||
@@ -32,7 +34,7 @@ class JXLBitstream:
|
||||
self.shift += length
|
||||
return bits
|
||||
|
||||
def partial_read(self, current_length, length):
|
||||
def partial_read(self, current_length: int, length: int) -> None:
|
||||
self.previous_data_len += self.offsets[self.index][2]
|
||||
to_read_length = self.previous_data_len - (self.shift + current_length)
|
||||
self.bitstream += self.file.read(to_read_length)
|
||||
@@ -44,7 +46,7 @@ class JXLBitstream:
|
||||
self.partial_read(current_length, length)
|
||||
|
||||
|
||||
def decode_codestream(file, offset=0, offsets=[]):
|
||||
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
|
||||
"""
|
||||
Decodes the actual codestream.
|
||||
JXL codestream specification: http://www-internal/2022/18181-1
|
||||
@@ -104,13 +106,13 @@ def decode_codestream(file, offset=0, offsets=[]):
|
||||
return width, height
|
||||
|
||||
|
||||
def decode_container(file):
|
||||
def decode_container(file) -> Tuple[int,int]:
|
||||
"""
|
||||
Parses the ISOBMFF container, extracts the codestream, and decodes it.
|
||||
JXL container specification: http://www-internal/2022/18181-2
|
||||
"""
|
||||
|
||||
def parse_box(file, file_start) -> dict:
|
||||
def parse_box(file, file_start: int) -> dict:
|
||||
file.seek(file_start)
|
||||
LBox = int.from_bytes(file.read(4), "big")
|
||||
XLBox = None
|
||||
@@ -177,7 +179,7 @@ def decode_container(file):
|
||||
return decode_codestream(file, offset=offset, offsets=offsets)
|
||||
|
||||
|
||||
def get_jxl_size(path):
|
||||
def get_jxl_size(path: str) -> Tuple[int,int]:
|
||||
with open(path, "rb") as file:
|
||||
if file.read(2) == bytes.fromhex("FF0A"):
|
||||
return decode_codestream(file)
|
||||
|
||||
Reference in New Issue
Block a user