Merge branch 'sd3' into feat-hunyuan-image-2.1-inference

This commit is contained in:
Kohya S
2025-09-13 20:13:58 +09:00
18 changed files with 465 additions and 237 deletions

View File

@@ -1,12 +1,34 @@
from concurrent.futures import ThreadPoolExecutor
import gc
import time
from typing import Any, Optional, Union, Callable, Tuple
import torch
import torch.nn as nn
from library.device_utils import clean_memory_on_device, synchronize_device
# region block swap utils
# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py
def _clean_memory_on_device(device: torch.device):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
if device.type == "xpu":
torch.xpu.empty_cache()
if device.type == "mps":
torch.mps.empty_cache()
def _synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
@@ -68,14 +90,14 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device(device)
_synchronize_device(device)
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device(device)
_synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -141,18 +163,24 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(
self,
blocks: list[nn.Module],
blocks: Union[list[nn.Module], nn.ModuleList],
blocks_to_swap: int,
device: torch.device,
supports_backward: bool = True,
debug: bool = False,
):
super().__init__(len(blocks), blocks_to_swap, device, debug)
@@ -176,7 +204,9 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
def create_backward_hook(
self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -217,8 +247,8 @@ class ModelOffloader(Offloader):
b.to(self.device) # move block to device first. this makes sure that buffers (non weights) are on the device
weighs_to_device(b, "cpu") # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
_synchronize_device(self.device)
_clean_memory_on_device(self.device)
def wait_for_block(self, block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0: