from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional

import torch
import torch.nn as nn

from library.device_utils import clean_memory_on_device


def synchronize_device(device: torch.device):
    """Wait for pending work on ``device`` by calling its backend's synchronize.

    Handles the ``cuda``, ``xpu`` and ``mps`` device types; any other device
    type (e.g. ``cpu``) is a no-op.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def _collect_weight_swap_jobs(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module) -> list:
    """Pair up the submodules of two layers whose weights are to be exchanged.

    Submodules are matched positionally via ``modules()``. A pair is kept only
    when the ``layer_to_cpu`` side has a non-``None`` ``weight`` attribute.

    Returns:
        A list of ``(module_to_cpu, module_to_cuda, device_data, cpu_data)``
        tuples, where the last two entries are the ``weight.data`` tensors of
        ``module_to_cpu`` and ``module_to_cuda`` at the time of the call.
    """
    return [
        (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)
        for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules())
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None
    ]


def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Exchange weight storage between two same-class layers using a CUDA side stream.

    The weights of ``layer_to_cpu`` are copied to CPU, then the tensors they
    occupied are overwritten in place with the weights of ``layer_to_cuda``
    and handed over to that layer.

    Args:
        layer_to_cpu: Layer whose weights are moved to CPU.
        layer_to_cuda: Layer whose weights are copied into the tensors that
            ``layer_to_cpu`` was using.

    Raises:
        AssertionError: If the two layers are not instances of the same class.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = _collect_weight_swap_jobs(layer_to_cpu, layer_to_cuda)

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to cpu
        for module_to_cpu, _, cuda_data_view, _ in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

        # the outgoing copies must finish before their source tensors are overwritten below
        stream.synchronize()

        # cpu to cuda
        for _, module_to_cuda, cuda_data_view, _ in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

    stream.synchronize()
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Exchange weight storage between two same-class layers without CUDA streams.

    Same procedure as ``swap_weight_devices``, but synchronization goes through
    ``synchronize_device(device)`` instead of a dedicated stream.

    not tested

    Args:
        device: Accelerator device to synchronize after each copy phase.
        layer_to_cpu: Layer whose weights are moved to CPU.
        layer_to_cuda: Layer whose weights are copied into the tensors that
            ``layer_to_cpu`` was using.

    Raises:
        AssertionError: If the two layers are not instances of the same class.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = _collect_weight_swap_jobs(layer_to_cpu, layer_to_cuda)

    # device to cpu
    for module_to_cpu, _, device_data_view, _ in weight_swap_jobs:
        module_to_cpu.weight.data = device_data_view.data.to("cpu", non_blocking=True)

    # synchronize_device requires the device argument; calling it bare raises TypeError
    synchronize_device(device)

    # cpu to device
    for _, module_to_cuda, device_data_view, _ in weight_swap_jobs:
        device_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
        module_to_cuda.weight.data = device_data_view

    synchronize_device(device)


def weighs_to_device(layer: nn.Module, device: torch.device):
    """Move the ``weight`` tensor of every submodule of ``layer`` to ``device``.

    Only attributes named ``weight`` that are not ``None`` are touched. The
    transfer is requested with ``non_blocking=True`` and this function does
    not synchronize afterwards.
    """
    for module in layer.modules():
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data = module.weight.data.to(device, non_blocking=True)


class Offloader:
    """
    common offloading class

    Runs block weight swaps on a single worker thread and keeps the resulting
    futures in ``self.futures``, keyed by the index of the block being moved
    onto the device.
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        """Store the configuration and create the single-worker thread pool.

        Args:
            num_blocks: Total number of blocks in the model.
            blocks_to_swap: Number of blocks taking part in swapping.
            device: Accelerator device used for the resident blocks.
            debug: When true, print timing information for moves and waits.
        """
        self.num_blocks = num_blocks
        self.blocks_to_swap = blocks_to_swap
        self.device = device
        self.debug = debug

        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.futures = {}  # block index moving to device -> Future of the swap
        self.cuda_available = device.type == "cuda"

    def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
        """Swap the weights of two blocks, using the CUDA stream path when the device is CUDA."""
        if self.cuda_available:
            swap_weight_devices(block_to_cpu, block_to_cuda)
        else:
            swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)

    def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
        """Queue a weight swap between two blocks on the worker thread.

        The future is stored in ``self.futures`` under ``block_idx_to_cuda``.
        """

        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
            if self.debug:
                start_time = time.perf_counter()
                print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")

            self.swap_weight_devices(block_to_cpu, block_to_cuda)

            if self.debug:
                print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
            return bidx_to_cpu, bidx_to_cuda

        block_to_cpu = blocks[block_idx_to_cpu]
        block_to_cuda = blocks[block_idx_to_cuda]

        self.futures[block_idx_to_cuda] = self.thread_pool.submit(
            move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
        )

    def _wait_blocks_move(self, block_idx):
        """Block until the pending swap that brings ``block_idx`` onto the device is done.

        Returns immediately when no swap is pending for that index. The future
        is removed from ``self.futures``; an exception raised inside the worker
        is re-raised here by ``Future.result()``.
        """
        if block_idx not in self.futures:
            return

        if self.debug:
            print(f"Wait for block {block_idx}")
            start_time = time.perf_counter()

        future = self.futures.pop(block_idx)
        _, bidx_to_cuda = future.result()

        assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"

        if self.debug:
            print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


class TrainOffloader(Offloader):
    """
    supports backward offloading
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        """Initialise the base offloader and the set of block indices that already have a hook."""
        super().__init__(num_blocks, blocks_to_swap, device, debug)
        self.hook_added = set()

    def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
        """Build the hook to run for ``block_index`` during backward, if one is needed.

        Args:
            blocks: All model blocks, indexed by block number.
            block_index: Index of the block the hook is created for.

        Returns:
            A callable taking one tensor argument that either submits a block
            swap or waits for one, or ``None`` when a hook was already created
            for this index or the block neither swaps nor waits.
        """
        if block_index in self.hook_added:
            return None
        self.hook_added.add(block_index)

        # -1 for 0-based index, -1 because the current block is not fully backpropagated yet
        num_blocks_propagated = self.num_blocks - block_index - 2
        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
        waiting = block_index > 0 and block_index <= self.blocks_to_swap

        if not swapping and not waiting:
            return None

        # create hook
        block_idx_to_cpu = self.num_blocks - num_blocks_propagated
        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
        block_idx_to_wait = block_index - 1

        if self.debug:
            print(
                f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}"
            )

        # NOTE(review): when both `swapping` and `waiting` are true only the swap
        # hook is created, so no wait is issued for `block_idx_to_wait` from here.
        # Confirm the caller waits for those blocks some other way.
        if swapping:

            def grad_hook(tensor: torch.Tensor):
                self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

            return grad_hook

        else:

            def grad_hook(tensor: torch.Tensor):
                self._wait_blocks_move(block_idx_to_wait)

            return grad_hook


class ModelOffloader(Offloader):
    """
    supports forward offloading
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        """Initialise the base offloader; no extra state is added."""
        super().__init__(num_blocks, blocks_to_swap, device, debug)

    def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
        """Place every block for the start of a forward pass.

        The first ``num_blocks - blocks_to_swap`` blocks end up fully on the
        device. The remaining blocks are moved to the device and then have
        their weights sent to CPU. Does nothing when swapping is disabled
        (``blocks_to_swap`` is ``None`` or ``0``).
        """
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return

        num_resident = self.num_blocks - self.blocks_to_swap

        for b in blocks[0:num_resident]:
            b.to(self.device)
            weighs_to_device(b, self.device)  # make sure weights are on device

        for b in blocks[num_resident:]:
            b.to(self.device)  # move block to device first
            weighs_to_device(b, "cpu")  # make sure weights are on cpu

        synchronize_device(self.device)
        clean_memory_on_device(self.device)

    def wait_for_block(self, block_idx: int):
        """Wait for a pending swap of ``block_idx``; no-op when swapping is disabled."""
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self._wait_blocks_move(block_idx)

    def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
        """Queue the swap that sends ``block_idx`` to CPU and brings its partner onto the device.

        The partner index is ``num_blocks - blocks_to_swap + block_idx``.
        Nothing is queued when swapping is disabled or when ``block_idx`` is
        not below ``blocks_to_swap``.
        """
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        if block_idx >= self.blocks_to_swap:
            return
        block_idx_to_cpu = block_idx
        block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
        self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)