mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
fix: revert constructor signature update
This commit is contained in:
@@ -89,8 +89,7 @@ class Offloader:
|
||||
common offloading class
|
||||
"""
|
||||
|
||||
def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||
self.block_type = block_type
|
||||
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks_to_swap = blocks_to_swap
|
||||
self.device = device
|
||||
@@ -110,16 +109,12 @@ class Offloader:
|
||||
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
||||
if self.debug:
|
||||
start_time = time.perf_counter()
|
||||
print(
|
||||
f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}"
|
||||
)
|
||||
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
|
||||
|
||||
self.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
|
||||
if self.debug:
|
||||
print(
|
||||
f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s"
|
||||
)
|
||||
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s")
|
||||
return bidx_to_cpu, bidx_to_cuda # , event
|
||||
|
||||
block_to_cpu = blocks[block_idx_to_cpu]
|
||||
@@ -134,7 +129,7 @@ class Offloader:
|
||||
return
|
||||
|
||||
if self.debug:
|
||||
print(f"[{self.block_type}] Wait for block {block_idx}")
|
||||
print(f"Wait for block {block_idx}")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
future = self.futures.pop(block_idx)
|
||||
@@ -143,7 +138,7 @@ class Offloader:
|
||||
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
|
||||
|
||||
if self.debug:
|
||||
print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
|
||||
print(f"Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
|
||||
|
||||
|
||||
class ModelOffloader(Offloader):
|
||||
@@ -152,10 +147,14 @@ class ModelOffloader(Offloader):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, blocks: list[nn.Module], blocks_to_swap: int, supports_backward: bool, device: torch.device, debug: bool = False
|
||||
self,
|
||||
blocks: list[nn.Module],
|
||||
blocks_to_swap: int,
|
||||
device: torch.device,
|
||||
supports_backward: bool = True,
|
||||
debug: bool = False,
|
||||
):
|
||||
block_type = f"{blocks[0].__class__.__name__}" if len(blocks) > 0 else "Unknown"
|
||||
super().__init__(block_type, len(blocks), blocks_to_swap, device, debug)
|
||||
super().__init__(len(blocks), blocks_to_swap, device, debug)
|
||||
|
||||
self.supports_backward = supports_backward
|
||||
self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference
|
||||
@@ -208,7 +207,7 @@ class ModelOffloader(Offloader):
|
||||
return
|
||||
|
||||
if self.debug:
|
||||
print(f"[{self.block_type}] Prepare block devices before forward")
|
||||
print(f"Prepare block devices before forward")
|
||||
|
||||
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
||||
b.to(self.device)
|
||||
|
||||
@@ -171,10 +171,10 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
)
|
||||
|
||||
self.offloader_double = custom_offloading_utils.ModelOffloader(
|
||||
self.double_blocks, double_blocks_to_swap, supports_backward, device
|
||||
self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward
|
||||
)
|
||||
self.offloader_single = custom_offloading_utils.ModelOffloader(
|
||||
self.single_blocks, single_blocks_to_swap, supports_backward, device
|
||||
self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward
|
||||
)
|
||||
# , debug=True
|
||||
print(
|
||||
|
||||
Reference in New Issue
Block a user