faster block swap

Kohya S
2024-11-05 21:22:42 +09:00
parent 5e32ee26a1
commit 81c0c965a2
3 changed files with 350 additions and 113 deletions

View File

@@ -17,12 +17,14 @@ import math
import os
from multiprocessing import Value
import time
-from typing import List
+from typing import List, Optional, Tuple, Union
import toml
from tqdm import tqdm
import torch
+import torch.nn as nn
+from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -466,45 +468,28 @@ def train(args):
# memory efficient block swapping
-def get_block_unit(dbl_blocks, sgl_blocks, index: int):
-if index < len(dbl_blocks):
-return (dbl_blocks[index],)
-else:
-index -= len(dbl_blocks)
-index *= 2
-return (sgl_blocks[index], sgl_blocks[index + 1])
-def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
-def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
-# print(f"Backward: Move block {bidx_to_cpu} to CPU")
-for block in blocks_to_cpu:
-block = block.to("cpu", non_blocking=True)
-torch.cuda.empty_cache()
-# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
-for block in blocks_to_cuda:
-block = block.to(dvc, non_blocking=True)
-torch.cuda.synchronize()
-# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
-return bidx_to_cpu, bidx_to_cuda
-blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
-blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
-futures[block_idx_to_cuda] = thread_pool.submit(
-move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
-)
-def wait_blocks_move(block_idx, futures):
-if block_idx not in futures:
+def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
+def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
+# start_time = time.perf_counter()
+# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
+utils.swap_weight_devices(block_to_cpu, block_to_cuda)
+# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
+return bidx_to_cpu, bidx_to_cuda # , event
+block_to_cpu = blocks[block_idx_to_cpu]
+block_to_cuda = blocks[block_idx_to_cuda]
+futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
+def wait_blocks_move(block_id, futures):
+if block_id not in futures:
return
-# print(f"Backward: Wait for block {block_idx}")
+# print(f"Backward: Wait for block {block_id}")
# start_time = time.perf_counter()
-future = futures.pop(block_idx)
-future.result()
-# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
-# torch.cuda.synchronize()
+future = futures.pop(block_id)
+_, bidx_to_cuda = future.result()
+assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
+# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
if args.fused_backward_pass:
@@ -513,11 +498,11 @@ def train(args):
library.adafactor_fused.patch_adafactor_fused(optimizer)
-blocks_to_swap = args.blocks_to_swap
+double_blocks_to_swap = args.blocks_to_swap // 2
+single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
-num_block_units = num_double_blocks + num_single_blocks // 2
-handled_unit_indices = set()
+handled_block_ids = set()
n = 1 # only asynchronous purpose, no need to increase this number
# n = 2
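For a concrete feel of the new split (illustration only, not part of the diff): the requested swap count is divided between double and single blocks, with single blocks now counted individually rather than in pairs.

# Hypothetical example of the split above, e.g. for --blocks_to_swap 6:
blocks_to_swap = 6
double_blocks_to_swap = blocks_to_swap // 2                           # 3 double blocks
single_blocks_to_swap = (blocks_to_swap - double_blocks_to_swap) * 2  # 6 single blocks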
@@ -530,28 +515,37 @@ def train(args):
if parameter.requires_grad:
grad_hook = None
-if blocks_to_swap:
+if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
is_double = param_name.startswith("double_blocks")
is_single = param_name.startswith("single_blocks")
-if is_double or is_single:
+if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
block_idx = int(param_name.split(".")[1])
-unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
-if unit_idx not in handled_unit_indices:
+block_id = (is_double, block_idx) # double or single, block index
+if block_id not in handled_block_ids:
# swap following (already backpropagated) block
-handled_unit_indices.add(unit_idx)
+handled_block_ids.add(block_id)
# if n blocks were already backpropagated
-num_blocks_propagated = num_block_units - unit_idx - 1
+if is_double:
+num_blocks = num_double_blocks
+blocks_to_swap = double_blocks_to_swap
+else:
+num_blocks = num_single_blocks
+blocks_to_swap = single_blocks_to_swap
+# -1 for 0-based index, -1 for current block is not fully backpropagated yet
+num_blocks_propagated = num_blocks - block_idx - 2
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
-waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
+waiting = block_idx > 0 and block_idx <= blocks_to_swap
if swapping or waiting:
-block_idx_to_cpu = num_block_units - num_blocks_propagated
+block_idx_to_cpu = num_blocks - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
-block_idx_to_wait = unit_idx - 1
+block_idx_to_wait = block_idx - 1
# create swap hook
def create_swap_grad_hook(
-bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
+is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -559,24 +553,25 @@ def train(args):
optimizer.step_param(tensor, param_group)
tensor.grad = None
-# print(f"Backward: {uidx}, {swpng}, {wtng}")
+# print(
+# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
+# )
if swpng:
submit_move_blocks(
futures,
thread_pool,
bidx_to_cpu,
bidx_to_cuda,
-flux.double_blocks,
-flux.single_blocks,
-accelerator.device,
+flux.double_blocks if is_dbl else flux.single_blocks,
+(is_dbl, bidx_to_cuda), # wait for this block
)
if wtng:
-wait_blocks_move(bidx_to_wait, futures)
+wait_blocks_move((is_dbl, bidx_to_wait), futures)
return __grad_hook
grad_hook = create_swap_grad_hook(
-block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
+is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
)
if grad_hook is None:
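To make the index arithmetic in the hook above concrete, here is a small worked example with hypothetical numbers (38 single blocks, 6 of them swapped); it is not taken from the diff, it just evaluates the same expressions.

num_blocks, blocks_to_swap, block_idx = 38, 6, 33
num_blocks_propagated = num_blocks - block_idx - 2                                 # 3
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap   # True: submit a swap
waiting = block_idx > 0 and block_idx <= blocks_to_swap                            # False: nothing to wait for yet
block_idx_to_cpu = num_blocks - num_blocks_propagated                              # 35
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated                         # 3
block_idx_to_wait = block_idx - 1                                                  # 32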

View File

@@ -7,8 +7,9 @@ from dataclasses import dataclass
import math
import os
import time
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
+from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -923,7 +924,8 @@ class Flux(nn.Module):
self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
-self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
+self.num_double_blocks = len(self.double_blocks)
+self.num_single_blocks = len(self.single_blocks)
@property
def device(self):
@@ -963,14 +965,17 @@ class Flux(nn.Module):
def enable_block_swap(self, num_blocks: int):
self.blocks_to_swap = num_blocks
+self.double_blocks_to_swap = num_blocks // 2
+self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2
+print(
+f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
+)
n = 1 # async block swap. 1 is enough
-# n = 2
-# n = max(1, os.cpu_count() // 2)
self.thread_pool = ThreadPoolExecutor(max_workers=n)
def move_to_device_except_swap_blocks(self, device: torch.device):
-# assume model is on cpu
+# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
@@ -983,31 +988,55 @@ class Flux(nn.Module):
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
-def get_block_unit(self, index: int):
-if index < len(self.double_blocks):
-return (self.double_blocks[index],)
-else:
-index -= len(self.double_blocks)
-index *= 2
-return self.single_blocks[index], self.single_blocks[index + 1]
-def get_unit_index(self, is_double: bool, index: int):
-if is_double:
-return index
-else:
-return len(self.double_blocks) + index // 2
+# def get_block_unit(self, index: int):
+# if index < len(self.double_blocks):
+# return (self.double_blocks[index],)
+# else:
+# index -= len(self.double_blocks)
+# index *= 2
+# return self.single_blocks[index], self.single_blocks[index + 1]
+# def get_unit_index(self, is_double: bool, index: int):
+# if is_double:
+# return index
+# else:
+# return len(self.double_blocks) + index // 2
def prepare_block_swap_before_forward(self):
-# make: first n blocks are on cuda, and last n blocks are on cpu
+# # make: first n blocks are on cuda, and last n blocks are on cpu
+# if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+# # raise ValueError("Block swap is not enabled.")
+# return
+# for i in range(self.num_block_units - self.blocks_to_swap):
+# for b in self.get_block_unit(i):
+# b.to(self.device)
+# for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
+# for b in self.get_block_unit(i):
+# b.to("cpu")
+# clean_memory_on_device(self.device)
+# all blocks are on device, but some weights are on cpu
+# make first n blocks weights on device, and last n blocks weights on cpu
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return
-for i in range(self.num_block_units - self.blocks_to_swap):
-for b in self.get_block_unit(i):
-b.to(self.device)
-for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
-for b in self.get_block_unit(i):
-b.to("cpu")
+for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
+b.to(self.device)
+utils.weighs_to_device(b, self.device) # make sure weights are on device
+for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
+b.to(self.device) # move block to device first
+utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
+torch.cuda.synchronize()
+clean_memory_on_device(self.device)
+for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
+b.to(self.device)
+utils.weighs_to_device(b, self.device) # make sure weights are on device
+for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
+b.to(self.device) # move block to device first
+utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
+torch.cuda.synchronize()
clean_memory_on_device(self.device)
def forward(
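As an illustration of the layout that the new prepare_block_swap_before_forward produces (the counts below are assumed, roughly matching FLUX.1's 19 double and 38 single blocks; this snippet is not part of the commit):

num_double_blocks, double_blocks_to_swap = 19, 3   # assumed values for illustration
num_single_blocks, single_blocks_to_swap = 38, 6

# Every block module stays instantiated on the GPU, but only the leading blocks
# keep their weight tensors there; the trailing ones stage their weights on the CPU.
double_weights_on_gpu = list(range(num_double_blocks - double_blocks_to_swap))                      # blocks 0..15
double_weights_on_cpu = list(range(num_double_blocks - double_blocks_to_swap, num_double_blocks))   # blocks 16..18
single_weights_on_gpu = list(range(num_single_blocks - single_blocks_to_swap))                      # blocks 0..31
single_weights_on_cpu = list(range(num_single_blocks - single_blocks_to_swap, num_single_blocks))   # blocks 32..37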
@@ -1044,27 +1073,22 @@ class Flux(nn.Module):
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
-futures = {}
-def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
-def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
-# print(f"Moving {bidx_to_cpu} to cpu.")
-for block in blocks_to_cpu:
-block.to("cpu", non_blocking=True)
-torch.cuda.empty_cache()
-# print(f"Moving {bidx_to_cuda} to cuda.")
-for block in blocks_to_cuda:
-block.to(self.device, non_blocking=True)
-torch.cuda.synchronize()
-# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
-return block_idx_to_cpu, block_idx_to_cuda
-blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
-blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
+# device = self.device
+def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
+def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
+start_time = time.perf_counter()
+# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
+utils.swap_weight_devices(block_to_cpu, block_to_cuda)
+# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
+# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
+return block_idx_to_cpu, block_idx_to_cuda # , event
+block_to_cpu = blocks[block_idx_to_cpu]
+block_to_cuda = blocks[block_idx_to_cuda]
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
-return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
+return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
@@ -1073,37 +1097,35 @@ class Flux(nn.Module):
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
-# torch.cuda.synchronize()
-# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
+# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
+double_futures = {}
for block_idx, block in enumerate(self.double_blocks):
# print(f"Double block {block_idx}")
-unit_idx = self.get_unit_index(is_double=True, index=block_idx)
-wait_for_blocks_move(unit_idx, futures)
+wait_for_blocks_move(block_idx, double_futures)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-if unit_idx < self.blocks_to_swap:
-block_idx_to_cpu = unit_idx
-block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
-future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
-futures[block_idx_to_cuda] = future
+if block_idx < self.double_blocks_to_swap:
+block_idx_to_cpu = block_idx
+block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
+future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
+double_futures[block_idx_to_cuda] = future
img = torch.cat((txt, img), 1)
+single_futures = {}
for block_idx, block in enumerate(self.single_blocks):
# print(f"Single block {block_idx}")
-unit_idx = self.get_unit_index(is_double=False, index=block_idx)
-if block_idx % 2 == 0:
-wait_for_blocks_move(unit_idx, futures)
+wait_for_blocks_move(block_idx, single_futures)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
-block_idx_to_cpu = unit_idx
-block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
-future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
-futures[block_idx_to_cuda] = future
+if block_idx < self.single_blocks_to_swap:
+block_idx_to_cpu = block_idx
+block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx
+future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
+single_futures[block_idx_to_cuda] = future
img = img[:, txt.shape[1] :, ...]
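The forward loops above follow a simple prefetch pattern. A minimal self-contained sketch of that pattern (dummy swap function and hypothetical counts, not the model's API):

from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)
futures = {}
num_blocks, blocks_to_swap = 8, 3   # hypothetical

def swap(idx_to_cpu, idx_to_cuda):
    # stands in for utils.swap_weight_devices(blocks[idx_to_cpu], blocks[idx_to_cuda])
    return idx_to_cpu, idx_to_cuda

for block_idx in range(num_blocks):
    if block_idx in futures:                    # wait until this block's weights are back on the GPU
        futures.pop(block_idx).result()
    # ... run block_idx here ...
    if block_idx < blocks_to_swap:              # after use, evict it and prefetch a later block
        idx_to_cuda = num_blocks - blocks_to_swap + block_idx
        futures[idx_to_cuda] = pool.submit(swap, block_idx, idx_to_cuda)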

View File

@@ -6,6 +6,7 @@ import json
import struct
import torch
+import torch.nn as nn
from torchvision import transforms
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
@@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False):
# region PyTorch utils # region PyTorch utils
# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
# # cpu_tensor = module_to_cuda.weight.data
# # cuda_tensor = module_to_cpu.weight.data
# # assert cuda_tensor.device.type == "cuda"
# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
# # torch.cuda.current_stream().synchronize()
# # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
# # torch.cuda.current_stream().synchronize()
# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
# cuda_tensor_view = module_to_cpu.weight.data
# cpu_tensor_view = module_to_cuda.weight.data
# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
# module_to_cuda.weight.data = cuda_tensor_view
# module_to_cuda.weight.data.copy_(cpu_tensor_view)
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
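A minimal usage sketch for the helper above (illustrative only; in the trainer the two arguments are FLUX double or single blocks and the call runs in a worker thread; a CUDA device is required, and only weight tensors are exchanged, other parameters such as biases stay where they are):

import torch.nn as nn
from library import utils

block_to_cpu = nn.Linear(1024, 1024).to("cuda")   # currently resident on the GPU
block_to_cuda = nn.Linear(1024, 1024)              # structurally identical, weights staged on the CPU

utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# Afterwards block_to_cpu's weight lives on the CPU, and block_to_cuda's weight
# now occupies the GPU storage that block_to_cpu was using.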
def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
stream_to_cpu = torch.cuda.Stream()
stream_to_cuda = torch.cuda.Stream()
events = []
with torch.cuda.stream(stream_to_cpu):
# cuda to offload
offloaded_weights = []
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
event = torch.cuda.Event()
event.record(stream=stream_to_cpu)
events.append(event)
with torch.cuda.stream(stream_to_cuda):
# cpu to cuda
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events):
event.synchronize()
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# offload to cpu
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip(
weight_swap_jobs, offloaded_weights
):
module_to_cpu.weight.data = offloaded_weight
stream_to_cuda.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
stream_to_cpu = torch.cuda.Stream()
stream_to_cuda = torch.cuda.Stream()
# cuda to offload
events = []
with torch.cuda.stream(stream_to_cpu):
offloaded_weights = []
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream_to_cpu)
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
event = torch.cuda.Event()
event.record(stream=stream_to_cpu)
events.append(event)
# cpu to cuda
with torch.cuda.stream(stream_to_cuda):
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip(
weight_swap_jobs, events, offloaded_weights
):
event.synchronize()
cuda_data_view.record_stream(stream_to_cuda)
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
module_to_cpu.weight.data = offloaded_weight
stream_to_cuda.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
# torch.cuda.current_stream().wait_stream(stream_to_cuda)
# for job in weight_swap_jobs:
# job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor
def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
# one of the modules must have the tensor to offload
module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
module_to_cpu.offloaded_weight.pin_memory()
offloaded_weight = (
module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
)
assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
weight_swap_jobs.append(
(module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to offload
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
cuda_data_view.record_stream(stream)
offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# offload to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
module_to_cpu.weight.data = offloaded_weight
offloaded_weight = cpu_data_view
module_to_cpu.offloaded_weight = offloaded_weight
module_to_cuda.offloaded_weight = offloaded_weight
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
# one of the modules must have the tensor to cache
module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
module_to_cpu.__cached_cpu_weight.pin_memory()
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)
torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish
torch.cuda.empty_cache()
# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
# weight_on_cuda = module_to_cpu.weight
# weight_on_cpu = module_to_cuda.weight
# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
# event = torch.cuda.current_stream().record_event()
# event.synchronize()
# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
# weight_on_cpu.data = cuda_to_cpu_data
# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
# module_to_cpu.weight = weight_on_cpu
# module_to_cuda.weight = weight_on_cuda
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
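A matching sketch for weighs_to_device (illustrative): it only retargets weight tensors, which is what lets prepare_block_swap_before_forward keep a block's module structure on the GPU while its weights wait on the CPU.

import torch.nn as nn
from library import utils

block = nn.Linear(8, 8).to("cuda")      # stand-in for a transformer block, fully on the GPU
utils.weighs_to_device(block, "cpu")    # weight tensors move to the CPU; other parameters (e.g. bias) stay on the GPU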
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
@@ -313,6 +533,7 @@ class MemoryEfficientSafeOpen:
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
@@ -336,7 +557,6 @@ def load_safetensors(
return state_dict
# endregion
# region Image utils