faster block swap

This commit is contained in:
Kohya S
2024-11-05 21:22:42 +09:00
parent 5e32ee26a1
commit 81c0c965a2
3 changed files with 350 additions and 113 deletions

View File

@@ -17,12 +17,14 @@ import math
import os
from multiprocessing import Value
import time
from typing import List
from typing import List, Optional, Tuple, Union
import toml
from tqdm import tqdm
import torch
import torch.nn as nn
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -466,45 +468,28 @@ def train(args):
# memory efficient block swapping
def get_block_unit(dbl_blocks, sgl_blocks, index: int):
if index < len(dbl_blocks):
return (dbl_blocks[index],)
else:
index -= len(dbl_blocks)
index *= 2
return (sgl_blocks[index], sgl_blocks[index + 1])
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
# start_time = time.perf_counter()
# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
# print(f"Backward: Move block {bidx_to_cpu} to CPU")
for block in blocks_to_cpu:
block = block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
for block in blocks_to_cuda:
block = block.to(dvc, non_blocking=True)
futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
torch.cuda.synchronize()
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
return bidx_to_cpu, bidx_to_cuda
blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
futures[block_idx_to_cuda] = thread_pool.submit(
move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
)
def wait_blocks_move(block_idx, futures):
if block_idx not in futures:
def wait_blocks_move(block_id, futures):
if block_id not in futures:
return
# print(f"Backward: Wait for block {block_idx}")
# print(f"Backward: Wait for block {block_id}")
# start_time = time.perf_counter()
future = futures.pop(block_idx)
future.result()
# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
# torch.cuda.synchronize()
future = futures.pop(block_id)
_, bidx_to_cuda = future.result()
assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
if args.fused_backward_pass:
@@ -513,11 +498,11 @@ def train(args):
library.adafactor_fused.patch_adafactor_fused(optimizer)
blocks_to_swap = args.blocks_to_swap
double_blocks_to_swap = args.blocks_to_swap // 2
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set()
handled_block_ids = set()
n = 1 # only asynchronous purpose, no need to increase this number
# n = 2
@@ -530,28 +515,37 @@ def train(args):
if parameter.requires_grad:
grad_hook = None
if blocks_to_swap:
if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
is_double = param_name.startswith("double_blocks")
is_single = param_name.startswith("single_blocks")
if is_double or is_single:
if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
block_idx = int(param_name.split(".")[1])
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
if unit_idx not in handled_unit_indices:
block_id = (is_double, block_idx) # double or single, block index
if block_id not in handled_block_ids:
# swap following (already backpropagated) block
handled_unit_indices.add(unit_idx)
handled_block_ids.add(block_id)
# if n blocks were already backpropagated
num_blocks_propagated = num_block_units - unit_idx - 1
if is_double:
num_blocks = num_double_blocks
blocks_to_swap = double_blocks_to_swap
else:
num_blocks = num_single_blocks
blocks_to_swap = single_blocks_to_swap
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
num_blocks_propagated = num_blocks - block_idx - 2
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
waiting = block_idx > 0 and block_idx <= blocks_to_swap
if swapping or waiting:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cpu = num_blocks - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
block_idx_to_wait = unit_idx - 1
block_idx_to_wait = block_idx - 1
# create swap hook
def create_swap_grad_hook(
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -559,24 +553,25 @@ def train(args):
optimizer.step_param(tensor, param_group)
tensor.grad = None
# print(f"Backward: {uidx}, {swpng}, {wtng}")
# print(
# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
# )
if swpng:
submit_move_blocks(
futures,
thread_pool,
bidx_to_cpu,
bidx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
flux.double_blocks if is_dbl else flux.single_blocks,
(is_dbl, bidx_to_cuda), # wait for this block
)
if wtng:
wait_blocks_move(bidx_to_wait, futures)
wait_blocks_move((is_dbl, bidx_to_wait), futures)
return __grad_hook
grad_hook = create_swap_grad_hook(
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
)
if grad_hook is None:

View File

@@ -7,8 +7,9 @@ from dataclasses import dataclass
import math
import os
import time
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -923,7 +924,8 @@ class Flux(nn.Module):
self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
@property
def device(self):
@@ -963,14 +965,17 @@ class Flux(nn.Module):
def enable_block_swap(self, num_blocks: int):
self.blocks_to_swap = num_blocks
self.double_blocks_to_swap = num_blocks // 2
self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
)
n = 1 # async block swap. 1 is enough
# n = 2
# n = max(1, os.cpu_count() // 2)
self.thread_pool = ThreadPoolExecutor(max_workers=n)
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
@@ -983,31 +988,55 @@ class Flux(nn.Module):
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
def get_block_unit(self, index: int):
if index < len(self.double_blocks):
return (self.double_blocks[index],)
else:
index -= len(self.double_blocks)
index *= 2
return self.single_blocks[index], self.single_blocks[index + 1]
# def get_block_unit(self, index: int):
# if index < len(self.double_blocks):
# return (self.double_blocks[index],)
# else:
# index -= len(self.double_blocks)
# index *= 2
# return self.single_blocks[index], self.single_blocks[index + 1]
def get_unit_index(self, is_double: bool, index: int):
if is_double:
return index
else:
return len(self.double_blocks) + index // 2
# def get_unit_index(self, is_double: bool, index: int):
# if is_double:
# return index
# else:
# return len(self.double_blocks) + index // 2
def prepare_block_swap_before_forward(self):
# make: first n blocks are on cuda, and last n blocks are on cpu
# # make: first n blocks are on cuda, and last n blocks are on cpu
# if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# # raise ValueError("Block swap is not enabled.")
# return
# for i in range(self.num_block_units - self.blocks_to_swap):
# for b in self.get_block_unit(i):
# b.to(self.device)
# for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
# for b in self.get_block_unit(i):
# b.to("cpu")
# clean_memory_on_device(self.device)
# all blocks are on device, but some weights are on cpu
# make first n blocks weights on device, and last n blocks weights on cpu
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return
for i in range(self.num_block_units - self.blocks_to_swap):
for b in self.get_block_unit(i):
b.to(self.device)
for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
for b in self.get_block_unit(i):
b.to("cpu")
for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)
for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)
def forward(
@@ -1044,27 +1073,22 @@ class Flux(nn.Module):
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
futures = {}
# device = self.device
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
# print(f"Moving {bidx_to_cpu} to cpu.")
for block in blocks_to_cpu:
block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
# print(f"Moving {bidx_to_cuda} to cuda.")
for block in blocks_to_cuda:
block.to(self.device, non_blocking=True)
torch.cuda.synchronize()
def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
start_time = time.perf_counter()
# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
return block_idx_to_cpu, block_idx_to_cuda
blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
return block_idx_to_cpu, block_idx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
@@ -1073,37 +1097,35 @@ class Flux(nn.Module):
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
# torch.cuda.synchronize()
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
double_futures = {}
for block_idx, block in enumerate(self.double_blocks):
# print(f"Double block {block_idx}")
unit_idx = self.get_unit_index(is_double=True, index=block_idx)
wait_for_blocks_move(unit_idx, futures)
wait_for_blocks_move(block_idx, double_futures)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
if block_idx < self.double_blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
double_futures[block_idx_to_cuda] = future
img = torch.cat((txt, img), 1)
single_futures = {}
for block_idx, block in enumerate(self.single_blocks):
# print(f"Single block {block_idx}")
unit_idx = self.get_unit_index(is_double=False, index=block_idx)
if block_idx % 2 == 0:
wait_for_blocks_move(unit_idx, futures)
wait_for_blocks_move(block_idx, single_futures)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
if block_idx < self.single_blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx
future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
single_futures[block_idx_to_cuda] = future
img = img[:, txt.shape[1] :, ...]

View File

@@ -6,6 +6,7 @@ import json
import struct
import torch
import torch.nn as nn
from torchvision import transforms
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
@@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False):
# region PyTorch utils
# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
# # cpu_tensor = module_to_cuda.weight.data
# # cuda_tensor = module_to_cpu.weight.data
# # assert cuda_tensor.device.type == "cuda"
# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
# # torch.cuda.current_stream().synchronize()
# # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
# # torch.cuda.current_stream().synchronize()
# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
# cuda_tensor_view = module_to_cpu.weight.data
# cpu_tensor_view = module_to_cuda.weight.data
# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
# module_to_cuda.weight.data = cuda_tensor_view
# module_to_cuda.weight.data.copy_(cpu_tensor_view)
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
stream_to_cpu = torch.cuda.Stream()
stream_to_cuda = torch.cuda.Stream()
events = []
with torch.cuda.stream(stream_to_cpu):
# cuda to offload
offloaded_weights = []
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
event = torch.cuda.Event()
event.record(stream=stream_to_cpu)
events.append(event)
with torch.cuda.stream(stream_to_cuda):
# cpu to cuda
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events):
event.synchronize()
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# offload to cpu
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip(
weight_swap_jobs, offloaded_weights
):
module_to_cpu.weight.data = offloaded_weight
stream_to_cuda.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
stream_to_cpu = torch.cuda.Stream()
stream_to_cuda = torch.cuda.Stream()
# cuda to offload
events = []
with torch.cuda.stream(stream_to_cpu):
offloaded_weights = []
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream_to_cpu)
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
event = torch.cuda.Event()
event.record(stream=stream_to_cpu)
events.append(event)
# cpu to cuda
with torch.cuda.stream(stream_to_cuda):
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip(
weight_swap_jobs, events, offloaded_weights
):
event.synchronize()
cuda_data_view.record_stream(stream_to_cuda)
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
module_to_cpu.weight.data = offloaded_weight
stream_to_cuda.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
# torch.cuda.current_stream().wait_stream(stream_to_cuda)
# for job in weight_swap_jobs:
# job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor
def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
# one of the modules must have the tensor to offload
module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
module_to_cpu.offloaded_weight.pin_memory()
offloaded_weight = (
module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
)
assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
weight_swap_jobs.append(
(module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to offload
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
cuda_data_view.record_stream(stream)
offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# offload to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
module_to_cpu.weight.data = offloaded_weight
offloaded_weight = cpu_data_view
module_to_cpu.offloaded_weight = offloaded_weight
module_to_cuda.offloaded_weight = offloaded_weight
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
# one of the modules must have the tensor to cache
module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
module_to_cpu.__cached_cpu_weight.pin_memory()
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)
torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish
torch.cuda.empty_cache()
# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
# weight_on_cuda = module_to_cpu.weight
# weight_on_cpu = module_to_cuda.weight
# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
# event = torch.cuda.current_stream().record_event()
# event.synchronize()
# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
# weight_on_cpu.data = cuda_to_cpu_data
# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
# module_to_cpu.weight = weight_on_cpu
# module_to_cuda.weight = weight_on_cuda
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
@@ -313,6 +533,7 @@ class MemoryEfficientSafeOpen:
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
@@ -336,7 +557,6 @@ def load_safetensors(
return state_dict
# endregion
# region Image utils