mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Merge branch 'sd3' into feat-hunyuan-image-2.1-inference
This commit is contained in:
@@ -1,12 +1,34 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import gc
|
||||
import time
|
||||
from typing import Any, Optional, Union, Callable, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from library.device_utils import clean_memory_on_device, synchronize_device
|
||||
|
||||
# region block swap utils
|
||||
# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py
|
||||
def _clean_memory_on_device(device: torch.device):
|
||||
r"""
|
||||
Clean memory on the specified device, will be called from training scripts.
|
||||
"""
|
||||
gc.collect()
|
||||
|
||||
# device may "cuda" or "cuda:0", so we need to check the type of device
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
if device.type == "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
if device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
def _synchronize_device(device: torch.device):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "xpu":
|
||||
torch.xpu.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
|
||||
|
||||
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
@@ -68,14 +90,14 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
synchronize_device(device)
|
||||
_synchronize_device(device)
|
||||
|
||||
# cpu to device
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
synchronize_device(device)
|
||||
_synchronize_device(device)
|
||||
|
||||
|
||||
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||||
@@ -141,18 +163,24 @@ class Offloader:
|
||||
print(f"Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
|
||||
|
||||
|
||||
# Type alias for the gradient arguments of a module backward hook: a single tensor or a tuple of tensors
|
||||
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
|
||||
|
||||
|
||||
class ModelOffloader(Offloader):
|
||||
"""
|
||||
supports forward offloading
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
self,
|
||||
blocks: list[nn.Module],
|
||||
blocks: Union[list[nn.Module], nn.ModuleList],
|
||||
blocks_to_swap: int,
|
||||
device: torch.device,
|
||||
supports_backward: bool = True,
|
||||
debug: bool = False,
|
||||
|
||||
):
|
||||
super().__init__(len(blocks), blocks_to_swap, device, debug)
|
||||
|
||||
@@ -176,7 +204,9 @@ class ModelOffloader(Offloader):
|
||||
for handle in self.remove_handles:
|
||||
handle.remove()
|
||||
|
||||
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
|
||||
def create_backward_hook(
|
||||
self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
|
||||
) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
|
||||
# -1 for 0-based index
|
||||
num_blocks_propagated = self.num_blocks - block_index - 1
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
||||
@@ -217,8 +247,8 @@ class ModelOffloader(Offloader):
|
||||
b.to(self.device) # move block to device first. this makes sure that buffers (non weights) are on the device
|
||||
weighs_to_device(b, "cpu") # make sure weights are on cpu
|
||||
|
||||
synchronize_device(self.device)
|
||||
clean_memory_on_device(self.device)
|
||||
_synchronize_device(self.device)
|
||||
_clean_memory_on_device(self.device)
|
||||
|
||||
def wait_for_block(self, block_idx: int):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
|
||||
Reference in New Issue
Block a user