faster block swap

This commit is contained in:
Kohya S
2024-11-05 21:22:42 +09:00
parent 5e32ee26a1
commit 81c0c965a2
3 changed files with 350 additions and 113 deletions

View File

@@ -7,8 +7,9 @@ from dataclasses import dataclass
import math
import os
import time
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -923,7 +924,8 @@ class Flux(nn.Module):
self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
@property
def device(self):
@@ -963,14 +965,17 @@ class Flux(nn.Module):
def enable_block_swap(self, num_blocks: int):
self.blocks_to_swap = num_blocks
self.double_blocks_to_swap = num_blocks // 2
self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
)
n = 1 # async block swap. 1 is enough
# n = 2
# n = max(1, os.cpu_count() // 2)
self.thread_pool = ThreadPoolExecutor(max_workers=n)
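# A minimal sketch (illustration only, not part of the Flux class) of the split used
# above: FLUX has 19 double blocks and 38 single blocks, and a double block is roughly
# the size of two single blocks, so N requested swap "units" become N // 2 double
# blocks plus twice the remainder in single blocks.
def _split_swap_count(num_blocks: int) -> tuple[int, int]:
    double_blocks_to_swap = num_blocks // 2
    single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
    return double_blocks_to_swap, single_blocks_to_swap

assert _split_swap_count(9) == (4, 10)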
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
@@ -983,31 +988,55 @@ class Flux(nn.Module):
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
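# The middle of this method is elided by the hunk above; presumably it detaches the
# block lists so that the module-wide move skips them, e.g. (plausible sketch only):
#     self.double_blocks = None
#     self.single_blocks = None
#     self.to(device)
# after which the saved lists are reattached as shown above.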
def get_block_unit(self, index: int):
if index < len(self.double_blocks):
return (self.double_blocks[index],)
else:
index -= len(self.double_blocks)
index *= 2
return self.single_blocks[index], self.single_blocks[index + 1]
# def get_block_unit(self, index: int):
# if index < len(self.double_blocks):
# return (self.double_blocks[index],)
# else:
# index -= len(self.double_blocks)
# index *= 2
# return self.single_blocks[index], self.single_blocks[index + 1]
def get_unit_index(self, is_double: bool, index: int):
if is_double:
return index
else:
return len(self.double_blocks) + index // 2
# def get_unit_index(self, is_double: bool, index: int):
# if is_double:
# return index
# else:
# return len(self.double_blocks) + index // 2
def prepare_block_swap_before_forward(self):
# make: first n blocks are on cuda, and last n blocks are on cpu
# # make: first n blocks are on cuda, and last n blocks are on cpu
# if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# # raise ValueError("Block swap is not enabled.")
# return
# for i in range(self.num_block_units - self.blocks_to_swap):
# for b in self.get_block_unit(i):
# b.to(self.device)
# for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
# for b in self.get_block_unit(i):
# b.to("cpu")
# clean_memory_on_device(self.device)
# all block modules are on the device here, but some of their weights may still be on cpu
# ensure the leading (non-swapped) blocks have their weights on the device, and the trailing (swapped) blocks have theirs on cpu

if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return
for i in range(self.num_block_units - self.blocks_to_swap):
for b in self.get_block_unit(i):
b.to(self.device)
for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
for b in self.get_block_unit(i):
b.to("cpu")
for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)
for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)
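# Sketch of the resulting layout (illustrative helper, not part of the class): every
# block module lives on the device, but the trailing `*_to_swap` blocks keep only their
# non-weight tensors there; their weights wait in cpu RAM until they are swapped in.
def _report_weight_devices(blocks):
    per_block = []
    for block in blocks:
        devs = {str(m.weight.device) for m in block.modules() if getattr(m, "weight", None) is not None}
        per_block.append(devs)
    return per_block
# e.g. with double_blocks_to_swap=4, the last 4 entries for the double blocks would read {"cpu"}.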
def forward(
@@ -1044,27 +1073,22 @@ class Flux(nn.Module):
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
futures = {}
# device = self.device
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
# print(f"Moving {bidx_to_cpu} to cpu.")
for block in blocks_to_cpu:
block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
# print(f"Moving {bidx_to_cuda} to cuda.")
for block in blocks_to_cuda:
block.to(self.device, non_blocking=True)
torch.cuda.synchronize()
def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
start_time = time.perf_counter()
# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
return block_idx_to_cpu, block_idx_to_cuda
blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
return block_idx_to_cpu, block_idx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
@@ -1073,37 +1097,35 @@ class Flux(nn.Module):
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
# torch.cuda.synchronize()
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
double_futures = {}
for block_idx, block in enumerate(self.double_blocks):
# print(f"Double block {block_idx}")
unit_idx = self.get_unit_index(is_double=True, index=block_idx)
wait_for_blocks_move(unit_idx, futures)
wait_for_blocks_move(block_idx, double_futures)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
if block_idx < self.double_blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
double_futures[block_idx_to_cuda] = future
img = torch.cat((txt, img), 1)
single_futures = {}
for block_idx, block in enumerate(self.single_blocks):
# print(f"Single block {block_idx}")
unit_idx = self.get_unit_index(is_double=False, index=block_idx)
if block_idx % 2 == 0:
wait_for_blocks_move(unit_idx, futures)
wait_for_blocks_move(block_idx, single_futures)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
if block_idx < self.single_blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx
future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
single_futures[block_idx_to_cuda] = future
img = img[:, txt.shape[1] :, ...]
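# The two loops above share one pattern; a minimal stand-alone sketch of it (names are
# illustrative, not from this file): a block's future is awaited only when its turn
# comes, so swapping a finished block's weights overlaps with compute on later blocks.
# `pool` is a single-worker executor like self.thread_pool, `process` runs the block,
# and `swap` exchanges weights between two blocks (e.g. utils.swap_weight_devices).
def _run_blocks_with_swap(blocks, blocks_to_swap, process, swap, pool):
    futures = {}
    for idx, block in enumerate(blocks):
        if idx in futures:
            futures.pop(idx).result()  # this block was being swapped in; wait for it
        process(block)  # forward pass through this block
        if idx < blocks_to_swap:
            idx_to_cpu = idx  # just finished, its device memory can be handed over
            idx_to_cuda = len(blocks) - blocks_to_swap + idx  # will be needed later
            futures[idx_to_cuda] = pool.submit(swap, blocks[idx_to_cpu], blocks[idx_to_cuda])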

View File

@@ -6,6 +6,7 @@ import json
import struct
import torch
import torch.nn as nn
from torchvision import transforms
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
@@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False):
# region PyTorch utils
# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
# # cpu_tensor = module_to_cuda.weight.data
# # cuda_tensor = module_to_cpu.weight.data
# # assert cuda_tensor.device.type == "cuda"
# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
# # torch.cuda.current_stream().synchronize()
# # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
# # torch.cuda.current_stream().synchronize()
# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
# cuda_tensor_view = module_to_cpu.weight.data
# cpu_tensor_view = module_to_cuda.weight.data
# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
# module_to_cuda.weight.data = cuda_tensor_view
# module_to_cuda.weight.data.copy_(cpu_tensor_view)
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
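# Usage sketch for swap_weight_devices (illustrative; assumes a CUDA device). The two
# layers must have identical structure, one holding weights on cuda and the other on
# cpu; the cuda-side storage is reused, so the swap allocates no new device memory.
def _swap_weight_devices_example():
    layer_on_cuda = nn.Linear(1024, 1024).to("cuda")
    layer_on_cpu = nn.Linear(1024, 1024)  # stays on cpu
    swap_weight_devices(layer_on_cuda, layer_on_cpu)
    # note: only .weight tensors are swapped; biases stay where they were
    assert layer_on_cuda.weight.device.type == "cpu"
    assert layer_on_cpu.weight.device.type == "cuda"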
def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
stream_to_cpu = torch.cuda.Stream()
stream_to_cuda = torch.cuda.Stream()
events = []
with torch.cuda.stream(stream_to_cpu):
# cuda to offload
offloaded_weights = []
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
event = torch.cuda.Event()
event.record(stream=stream_to_cpu)
events.append(event)
with torch.cuda.stream(stream_to_cuda):
# cpu to cuda
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events):
event.synchronize()
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# offload to cpu
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip(
weight_swap_jobs, offloaded_weights
):
module_to_cpu.weight.data = offloaded_weight
stream_to_cuda.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
stream_to_cpu = torch.cuda.Stream()
stream_to_cuda = torch.cuda.Stream()
# cuda to offload
events = []
with torch.cuda.stream(stream_to_cpu):
offloaded_weights = []
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream_to_cpu)
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
event = torch.cuda.Event()
event.record(stream=stream_to_cpu)
events.append(event)
# cpu to cuda
with torch.cuda.stream(stream_to_cuda):
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip(
weight_swap_jobs, events, offloaded_weights
):
event.synchronize()
cuda_data_view.record_stream(stream_to_cuda)
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
module_to_cpu.weight.data = offloaded_weight
stream_to_cuda.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
# torch.cuda.current_stream().wait_stream(stream_to_cuda)
# for job in weight_swap_jobs:
# job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor
def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
# one of the modules must have the tensor to offload
module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
module_to_cpu.offloaded_weight.pin_memory()
offloaded_weight = (
module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
)
assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
weight_swap_jobs.append(
(module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to offload
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
cuda_data_view.record_stream(stream)
offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# offload to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
module_to_cpu.weight.data = offloaded_weight
offloaded_weight = cpu_data_view
module_to_cpu.offloaded_weight = offloaded_weight
module_to_cuda.offloaded_weight = offloaded_weight
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
# one of the modules must have the tensor to cache
module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
module_to_cpu.__cached_cpu_weight.pin_memory()
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)
torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish
torch.cuda.empty_cache()
# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
# weight_on_cuda = module_to_cpu.weight
# weight_on_cpu = module_to_cuda.weight
# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
# event = torch.cuda.current_stream().record_event()
# event.synchronize()
# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
# weight_on_cpu.data = cuda_to_cpu_data
# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
# module_to_cpu.weight = weight_on_cpu
# module_to_cuda.weight = weight_on_cuda
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
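# Usage sketch for weighs_to_device (illustrative; assumes a CUDA device): only the
# .weight tensors are moved, while the module itself and its biases/buffers stay put.
# This is what lets prepare_block_swap_before_forward keep a block registered on the
# device while its (large) weights sit in cpu RAM.
def _weighs_to_device_example():
    block = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).to("cuda")
    weighs_to_device(block, torch.device("cpu"))
    assert all(m.weight.device.type == "cpu" for m in block)   # weights moved
    assert all(m.bias.device.type == "cuda" for m in block)    # biases untouched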
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
@@ -313,6 +533,7 @@ class MemoryEfficientSafeOpen:
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
@@ -336,7 +557,6 @@ def load_safetensors(
return state_dict
# endregion
# region Image utils