refactor: update imports to use safetensors_utils for memory-efficient operations

This commit is contained in:
Kohya S
2025-09-13 21:07:11 +09:00
parent d831c88832
commit 4e2a80a6ca
7 changed files with 20 additions and 16 deletions

View File

@@ -173,14 +173,12 @@ class ModelOffloader(Offloader):
"""
def __init__(
self,
blocks: Union[list[nn.Module], nn.ModuleList],
blocks_to_swap: int,
device: torch.device,
supports_backward: bool = True,
debug: bool = False,
):
super().__init__(len(blocks), blocks_to_swap, device, debug)
@@ -220,7 +218,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output):
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug:
print(f"Backward hook for block {block_index}")
@@ -232,7 +230,7 @@ class ModelOffloader(Offloader):
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
@@ -245,7 +243,7 @@ class ModelOffloader(Offloader):
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first. this makes sure that buffers (non weights) are on the device
weighs_to_device(b, "cpu") # make sure weights are on cpu
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu
_synchronize_device(self.device)
_clean_memory_on_device(self.device)
@@ -255,7 +253,7 @@ class ModelOffloader(Offloader):
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
# check if blocks_to_swap is enabled
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
@@ -266,7 +264,10 @@ class ModelOffloader(Offloader):
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading
# this works for forward-only offloading. move upstream blocks to cuda
block_idx_to_cuda = block_idx_to_cuda % self.num_blocks
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)