fix: revert constructor signature update

This commit is contained in:
Kohya S
2025-09-11 22:27:00 +09:00
parent 7f983c558d
commit a0f0afbb46
2 changed files with 15 additions and 16 deletions

View File

@@ -89,8 +89,7 @@ class Offloader:
common offloading class
"""
-def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
-self.block_type = block_type
+def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
self.num_blocks = num_blocks
self.blocks_to_swap = blocks_to_swap
self.device = device
@@ -110,16 +109,12 @@ class Offloader:
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
if self.debug:
start_time = time.perf_counter()
-print(
-f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}"
-)
+print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
self.swap_weight_devices(block_to_cpu, block_to_cuda)
if self.debug:
-print(
-f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s"
-)
+print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
@@ -134,7 +129,7 @@ class Offloader:
return
if self.debug:
-print(f"[{self.block_type}] Wait for block {block_idx}")
+print(f"Wait for block {block_idx}")
start_time = time.perf_counter()
future = self.futures.pop(block_idx)
@@ -143,7 +138,7 @@ class Offloader:
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
if self.debug:
-print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
+print(f"Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
class ModelOffloader(Offloader):
@@ -152,10 +147,14 @@ class ModelOffloader(Offloader):
"""
def __init__(
-self, blocks: list[nn.Module], blocks_to_swap: int, supports_backward: bool, device: torch.device, debug: bool = False
+self,
+blocks: list[nn.Module],
+blocks_to_swap: int,
+device: torch.device,
+supports_backward: bool = True,
+debug: bool = False,
):
-block_type = f"{blocks[0].__class__.__name__}" if len(blocks) > 0 else "Unknown"
-super().__init__(block_type, len(blocks), blocks_to_swap, device, debug)
+super().__init__(len(blocks), blocks_to_swap, device, debug)
self.supports_backward = supports_backward
self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference
@@ -208,7 +207,7 @@ class ModelOffloader(Offloader):
return
if self.debug:
-print(f"[{self.block_type}] Prepare block devices before forward")
+print(f"Prepare block devices before forward")
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)

View File

@@ -171,10 +171,10 @@ class HYImageDiffusionTransformer(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
-self.double_blocks, double_blocks_to_swap, supports_backward, device
+self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
-self.single_blocks, single_blocks_to_swap, supports_backward, device
+self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward
)
# , debug=True
print(