Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00

Refactor block swapping to utilize custom offloading utilities

library/custom_offloading_utils.py (new file, 216 lines)
@@ -0,0 +1,216 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
import torch
import torch.nn as nn

from library.device_utils import clean_memory_on_device


def synchronize_device(device: torch.device):
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to cpu
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

    stream.synchronize()
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """
    not tested
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    # device to cpu
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

    synchronize_device(device)

    # cpu to device
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
        module_to_cuda.weight.data = cuda_data_view

    synchronize_device(device)


def weighs_to_device(layer: nn.Module, device: torch.device):
    for module in layer.modules():
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data = module.weight.data.to(device, non_blocking=True)


class Offloader:
    """
    common offloading class
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        self.num_blocks = num_blocks
        self.blocks_to_swap = blocks_to_swap
        self.device = device
        self.debug = debug

        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.futures = {}
        self.cuda_available = device.type == "cuda"

    def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
        if self.cuda_available:
            swap_weight_devices(block_to_cpu, block_to_cuda)
        else:
            swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)

    def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
            if self.debug:
                start_time = time.perf_counter()
                print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")

            self.swap_weight_devices(block_to_cpu, block_to_cuda)

            if self.debug:
                print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
            return bidx_to_cpu, bidx_to_cuda  # , event

        block_to_cpu = blocks[block_idx_to_cpu]
        block_to_cuda = blocks[block_idx_to_cuda]

        self.futures[block_idx_to_cuda] = self.thread_pool.submit(
            move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
        )

    def _wait_blocks_move(self, block_idx):
        if block_idx not in self.futures:
            return

        if self.debug:
            print(f"Wait for block {block_idx}")
            start_time = time.perf_counter()

        future = self.futures.pop(block_idx)
        _, bidx_to_cuda = future.result()

        assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"

        if self.debug:
            print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


class TrainOffloader(Offloader):
    """
    supports backward offloading
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        super().__init__(num_blocks, blocks_to_swap, device, debug)
        self.hook_added = set()

    def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
        if block_index in self.hook_added:
            return None
        self.hook_added.add(block_index)

        # -1 for 0-based index, -1 for current block is not fully backpropagated yet
        num_blocks_propagated = self.num_blocks - block_index - 2
        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
        waiting = block_index > 0 and block_index <= self.blocks_to_swap

        if not swapping and not waiting:
            return None

        # create hook
        block_idx_to_cpu = self.num_blocks - num_blocks_propagated
        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
        block_idx_to_wait = block_index - 1

        if self.debug:
            print(
                f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}"
            )
        if swapping:

            def grad_hook(tensor: torch.Tensor):
                self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

            return grad_hook

        else:

            def grad_hook(tensor: torch.Tensor):
                self._wait_blocks_move(block_idx_to_wait)

            return grad_hook
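
TrainOffloader hands out at most one hook per block (the hook_added set dedupes repeated calls), and each hook either submits an asynchronous swap or waits for an earlier one while the backward pass is still running. As a rough illustration only (not code from this commit), a training script on PyTorch 2.1+ could wire the hooks up roughly like this; the block list, the swap count, and the per-parameter registration are assumptions:

# Illustrative wiring only; `blocks` and blocks_to_swap=4 are placeholders.
offloader = TrainOffloader(num_blocks=len(blocks), blocks_to_swap=4, device=torch.device("cuda"), debug=True)
for block_index, block in enumerate(blocks):
    for param in block.parameters():
        hook = offloader.create_grad_hook(blocks, block_index)
        if hook is not None:
            # fires after this parameter's gradient has been accumulated (PyTorch 2.1+)
            param.register_post_accumulate_grad_hook(hook)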


class ModelOffloader(Offloader):
    """
    supports forward offloading
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        super().__init__(num_blocks, blocks_to_swap, device, debug)

    def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return

        for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
            b.to(self.device)
            weighs_to_device(b, self.device)  # make sure weights are on device

        for b in blocks[self.num_blocks - self.blocks_to_swap :]:
            b.to(self.device)  # move block to device first
            weighs_to_device(b, "cpu")  # make sure weights are on cpu

        synchronize_device(self.device)
        clean_memory_on_device(self.device)

    def wait_for_block(self, block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self._wait_blocks_move(block_idx)

    def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        if block_idx >= self.blocks_to_swap:
            return
        block_idx_to_cpu = block_idx
        block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
        self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
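
ModelOffloader drives the same thread pool from the forward pass: wait for the block about to run, execute it, then submit an asynchronous swap that pushes this block's weights to CPU and pulls a later block's weights onto the device. A minimal sketch of the intended loop, using placeholder names (the actual wiring for FLUX is in the Flux model changes below):

# Sketch only; `blocks`, `x`, and blocks_to_swap=8 are illustrative assumptions.
offloader = ModelOffloader(num_blocks=len(blocks), blocks_to_swap=8, device=torch.device("cuda"))
offloader.prepare_block_devices_before_forward(blocks)  # last 8 blocks keep their weights on CPU
for block_idx, block in enumerate(blocks):
    offloader.wait_for_block(block_idx)                  # ensure this block's weights are on the device
    x = block(x)
    offloader.submit_move_blocks(blocks, block_idx)      # async: this block -> CPU, a tail block -> device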

library/flux_models.py

@@ -18,6 +18,7 @@ import torch
 from einops import rearrange
 from torch import Tensor, nn
 from torch.utils.checkpoint import checkpoint
+from library import custom_offloading_utils
 
 # USE_REENTRANT = True
 
@@ -923,7 +924,8 @@ class Flux(nn.Module):
         self.cpu_offload_checkpointing = False
         self.blocks_to_swap = None
 
-        self.thread_pool: Optional[ThreadPoolExecutor] = None
+        self.offloader_double = None
+        self.offloader_single = None
         self.num_double_blocks = len(self.double_blocks)
         self.num_single_blocks = len(self.single_blocks)
 
@@ -963,16 +965,16 @@ class Flux(nn.Module):
 
         print("FLUX: Gradient checkpointing disabled.")
 
-    def enable_block_swap(self, num_blocks: int):
+    def enable_block_swap(self, num_blocks: int, device: torch.device):
         self.blocks_to_swap = num_blocks
-        self.double_blocks_to_swap = num_blocks // 2
-        self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2
-        print(
-            f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
-        )
+        double_blocks_to_swap = num_blocks // 2
+        single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
 
-        n = 1  # async block swap. 1 is enough
-        self.thread_pool = ThreadPoolExecutor(max_workers=n)
+        self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device)
+        self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device)
+        print(
+            f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
+        )
 
     def move_to_device_except_swap_blocks(self, device: torch.device):
         # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
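
With the new signature, the caller passes the compute device when enabling the swap, and the offloaders are created inside enable_block_swap. A hypothetical call sequence from a training script (the swap count of 16 and the variable names are examples, not taken from this commit):

# flux: an instance of the Flux model above; illustrative values only
device = torch.device("cuda")
flux.enable_block_swap(16, device)              # device argument is new with this refactor
flux.move_to_device_except_swap_blocks(device)  # model assumed on CPU; swappable blocks stay put
flux.prepare_block_swap_before_forward()        # then run forward as usual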

@@ -988,56 +990,11 @@ class Flux(nn.Module):
         self.double_blocks = save_double_blocks
         self.single_blocks = save_single_blocks
 
-    # def get_block_unit(self, index: int):
-    #     if index < len(self.double_blocks):
-    #         return (self.double_blocks[index],)
-    #     else:
-    #         index -= len(self.double_blocks)
-    #         index *= 2
-    #         return self.single_blocks[index], self.single_blocks[index + 1]
-
-    # def get_unit_index(self, is_double: bool, index: int):
-    #     if is_double:
-    #         return index
-    #     else:
-    #         return len(self.double_blocks) + index // 2
-
     def prepare_block_swap_before_forward(self):
-        # # make: first n blocks are on cuda, and last n blocks are on cpu
-        # if self.blocks_to_swap is None or self.blocks_to_swap == 0:
-        #     # raise ValueError("Block swap is not enabled.")
-        #     return
-        # for i in range(self.num_block_units - self.blocks_to_swap):
-        #     for b in self.get_block_unit(i):
-        #         b.to(self.device)
-        # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
-        #     for b in self.get_block_unit(i):
-        #         b.to("cpu")
-        # clean_memory_on_device(self.device)
-
-        # all blocks are on device, but some weights are on cpu
-        # make first n blocks weights on device, and last n blocks weights on cpu
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
-            # raise ValueError("Block swap is not enabled.")
             return
-
-        for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
-            b.to(self.device)
-            utils.weighs_to_device(b, self.device)  # make sure weights are on device
-        for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
-            b.to(self.device)  # move block to device first
-            utils.weighs_to_device(b, "cpu")  # make sure weights are on cpu
-        torch.cuda.synchronize()
-        clean_memory_on_device(self.device)
-
-        for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
-            b.to(self.device)
-            utils.weighs_to_device(b, self.device)  # make sure weights are on device
-        for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
-            b.to(self.device)  # move block to device first
-            utils.weighs_to_device(b, "cpu")  # make sure weights are on cpu
-        torch.cuda.synchronize()
-        clean_memory_on_device(self.device)
+        self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
+        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
 
     def forward(
         self,
@@ -1073,59 +1030,21 @@ class Flux(nn.Module):
             for block in self.single_blocks:
                 img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
         else:
-            # device = self.device
-
-            def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
-                def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
-                    # start_time = time.perf_counter()
-                    # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
-                    utils.swap_weight_devices(block_to_cpu, block_to_cuda)
-                    # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
-
-                    # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
-                    return block_idx_to_cpu, block_idx_to_cuda  # , event
-
-                block_to_cpu = blocks[block_idx_to_cpu]
-                block_to_cuda = blocks[block_idx_to_cuda]
-                # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
-                return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
-
-            def wait_for_blocks_move(block_idx, ftrs):
-                if block_idx not in ftrs:
-                    return
-                # print(f"Waiting for move blocks: {block_idx}")
-                # start_time = time.perf_counter()
-                ftr = ftrs.pop(block_idx)
-                ftr.result()
-                # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
-
-            double_futures = {}
             for block_idx, block in enumerate(self.double_blocks):
-                # print(f"Double block {block_idx}")
-                wait_for_blocks_move(block_idx, double_futures)
+                self.offloader_double.wait_for_block(block_idx)
 
                 img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
 
-                if block_idx < self.double_blocks_to_swap:
-                    block_idx_to_cpu = block_idx
-                    block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
-                    future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
-                    double_futures[block_idx_to_cuda] = future
+                self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
 
             img = torch.cat((txt, img), 1)
 
-            single_futures = {}
             for block_idx, block in enumerate(self.single_blocks):
-                # print(f"Single block {block_idx}")
-                wait_for_blocks_move(block_idx, single_futures)
+                self.offloader_single.wait_for_block(block_idx)
 
                 img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
 
-                if block_idx < self.single_blocks_to_swap:
-                    block_idx_to_cpu = block_idx
-                    block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx
-                    future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
-                    single_futures[block_idx_to_cuda] = future
+                self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
 
         img = img[:, txt.shape[1] :, ...]