diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 4fbea542..8699b344 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -89,8 +89,7 @@ class Offloader: common offloading class """ - def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): - self.block_type = block_type + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): self.num_blocks = num_blocks self.blocks_to_swap = blocks_to_swap self.device = device @@ -110,16 +109,12 @@ class Offloader: def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): if self.debug: start_time = time.perf_counter() - print( - f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}" - ) + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") self.swap_weight_devices(block_to_cpu, block_to_cuda) if self.debug: - print( - f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s" - ) + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s") return bidx_to_cpu, bidx_to_cuda # , event block_to_cpu = blocks[block_idx_to_cpu] @@ -134,7 +129,7 @@ class Offloader: return if self.debug: - print(f"[{self.block_type}] Wait for block {block_idx}") + print(f"Wait for block {block_idx}") start_time = time.perf_counter() future = self.futures.pop(block_idx) @@ -143,7 +138,7 @@ class Offloader: assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" if self.debug: - print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s") + print(f"Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s") class ModelOffloader(Offloader): @@ -152,10 +147,14 @@ class ModelOffloader(Offloader): """ def __init__( - self, blocks: list[nn.Module], blocks_to_swap: int, supports_backward: bool, device: torch.device, debug: bool = False + self, + blocks: list[nn.Module], + blocks_to_swap: int, + device: torch.device, + supports_backward: bool = True, + debug: bool = False, ): - block_type = f"{blocks[0].__class__.__name__}" if len(blocks) > 0 else "Unknown" - super().__init__(block_type, len(blocks), blocks_to_swap, device, debug) + super().__init__(len(blocks), blocks_to_swap, device, debug) self.supports_backward = supports_backward self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference @@ -208,7 +207,7 @@ class ModelOffloader(Offloader): return if self.debug: - print(f"[{self.block_type}] Prepare block devices before forward") + print(f"Prepare block devices before forward") for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: b.to(self.device) diff --git a/library/hunyuan_image_models.py b/library/hunyuan_image_models.py index 9847c55e..9e3a00e8 100644 --- a/library/hunyuan_image_models.py +++ b/library/hunyuan_image_models.py @@ -171,10 +171,10 @@ class HYImageDiffusionTransformer(nn.Module): ) self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, double_blocks_to_swap, supports_backward, device + self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, single_blocks_to_swap, supports_backward, device + self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward ) # , debug=True print(