mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
faster block swap
This commit is contained in:
103
flux_train.py
103
flux_train.py
@@ -17,12 +17,14 @@ import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import time
|
||||
from typing import List
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from library import utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -466,45 +468,28 @@ def train(args):
|
||||
|
||||
# memory efficient block swapping
|
||||
|
||||
def get_block_unit(dbl_blocks, sgl_blocks, index: int):
    """Return the blocks that make up swap unit ``index`` as a tuple.

    Units are numbered with all double blocks first (one block per unit),
    followed by the single blocks grouped in consecutive pairs.

    Args:
        dbl_blocks: indexable sequence of double blocks.
        sgl_blocks: indexable sequence of single blocks.
        index: unit index.

    Returns:
        A 1-tuple holding one double block, or a 2-tuple holding two
        adjacent single blocks.
    """
    num_double = len(dbl_blocks)
    if index >= num_double:
        # past the double blocks: unit k maps to single blocks 2k and 2k + 1
        first = (index - num_double) * 2
        return (sgl_blocks[first], sgl_blocks[first + 1])
    return (dbl_blocks[index],)
|
||||
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
|
||||
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
||||
# start_time = time.perf_counter()
|
||||
# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
|
||||
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
||||
return bidx_to_cpu, bidx_to_cuda # , event
|
||||
|
||||
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
|
||||
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
|
||||
# print(f"Backward: Move block {bidx_to_cpu} to CPU")
|
||||
for block in blocks_to_cpu:
|
||||
block = block.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
block_to_cpu = blocks[block_idx_to_cpu]
|
||||
block_to_cuda = blocks[block_idx_to_cuda]
|
||||
|
||||
# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
|
||||
for block in blocks_to_cuda:
|
||||
block = block.to(dvc, non_blocking=True)
|
||||
futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
|
||||
return bidx_to_cpu, bidx_to_cuda
|
||||
|
||||
blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
|
||||
blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
|
||||
|
||||
futures[block_idx_to_cuda] = thread_pool.submit(
|
||||
move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
|
||||
)
|
||||
|
||||
def wait_blocks_move(block_idx, futures):
|
||||
if block_idx not in futures:
|
||||
def wait_blocks_move(block_id, futures):
|
||||
if block_id not in futures:
|
||||
return
|
||||
# print(f"Backward: Wait for block {block_idx}")
|
||||
# print(f"Backward: Wait for block {block_id}")
|
||||
# start_time = time.perf_counter()
|
||||
future = futures.pop(block_idx)
|
||||
future.result()
|
||||
# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
||||
# torch.cuda.synchronize()
|
||||
future = futures.pop(block_id)
|
||||
_, bidx_to_cuda = future.result()
|
||||
assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
|
||||
# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
|
||||
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
|
||||
|
||||
if args.fused_backward_pass:
|
||||
@@ -513,11 +498,11 @@ def train(args):
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
blocks_to_swap = args.blocks_to_swap
|
||||
double_blocks_to_swap = args.blocks_to_swap // 2
|
||||
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
|
||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||
handled_unit_indices = set()
|
||||
handled_block_ids = set()
|
||||
|
||||
n = 1 # only asynchronous purpose, no need to increase this number
|
||||
# n = 2
|
||||
@@ -530,28 +515,37 @@ def train(args):
|
||||
if parameter.requires_grad:
|
||||
grad_hook = None
|
||||
|
||||
if blocks_to_swap:
|
||||
if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
|
||||
is_double = param_name.startswith("double_blocks")
|
||||
is_single = param_name.startswith("single_blocks")
|
||||
if is_double or is_single:
|
||||
if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
|
||||
block_idx = int(param_name.split(".")[1])
|
||||
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
|
||||
if unit_idx not in handled_unit_indices:
|
||||
block_id = (is_double, block_idx) # double or single, block index
|
||||
if block_id not in handled_block_ids:
|
||||
# swap following (already backpropagated) block
|
||||
handled_unit_indices.add(unit_idx)
|
||||
handled_block_ids.add(block_id)
|
||||
|
||||
# if n blocks were already backpropagated
|
||||
num_blocks_propagated = num_block_units - unit_idx - 1
|
||||
if is_double:
|
||||
num_blocks = num_double_blocks
|
||||
blocks_to_swap = double_blocks_to_swap
|
||||
else:
|
||||
num_blocks = num_single_blocks
|
||||
blocks_to_swap = single_blocks_to_swap
|
||||
|
||||
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
|
||||
num_blocks_propagated = num_blocks - block_idx - 2
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
||||
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
|
||||
waiting = block_idx > 0 and block_idx <= blocks_to_swap
|
||||
|
||||
if swapping or waiting:
|
||||
block_idx_to_cpu = num_block_units - num_blocks_propagated
|
||||
block_idx_to_cpu = num_blocks - num_blocks_propagated
|
||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
||||
block_idx_to_wait = unit_idx - 1
|
||||
block_idx_to_wait = block_idx - 1
|
||||
|
||||
# create swap hook
|
||||
def create_swap_grad_hook(
|
||||
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
|
||||
is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
|
||||
):
|
||||
def __grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
@@ -559,24 +553,25 @@ def train(args):
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
|
||||
# print(f"Backward: {uidx}, {swpng}, {wtng}")
|
||||
# print(
|
||||
# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
|
||||
# )
|
||||
if swpng:
|
||||
submit_move_blocks(
|
||||
futures,
|
||||
thread_pool,
|
||||
bidx_to_cpu,
|
||||
bidx_to_cuda,
|
||||
flux.double_blocks,
|
||||
flux.single_blocks,
|
||||
accelerator.device,
|
||||
flux.double_blocks if is_dbl else flux.single_blocks,
|
||||
(is_dbl, bidx_to_cuda), # wait for this block
|
||||
)
|
||||
if wtng:
|
||||
wait_blocks_move(bidx_to_wait, futures)
|
||||
wait_blocks_move((is_dbl, bidx_to_wait), futures)
|
||||
|
||||
return __grad_hook
|
||||
|
||||
grad_hook = create_swap_grad_hook(
|
||||
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
|
||||
is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
|
||||
)
|
||||
|
||||
if grad_hook is None:
|
||||
|
||||
@@ -7,8 +7,9 @@ from dataclasses import dataclass
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from library import utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -923,7 +924,8 @@ class Flux(nn.Module):
|
||||
self.blocks_to_swap = None
|
||||
|
||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
|
||||
self.num_double_blocks = len(self.double_blocks)
|
||||
self.num_single_blocks = len(self.single_blocks)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
@@ -963,14 +965,17 @@ class Flux(nn.Module):
|
||||
|
||||
def enable_block_swap(self, num_blocks: int):
    """Turn on block swapping and create the worker pool used for it.

    The requested count is split between the two block kinds: half of it
    (rounded down) is used for double blocks, and twice the remainder is
    used for single blocks.

    Args:
        num_blocks: total number of blocks to swap.
    """
    double_count = num_blocks // 2
    single_count = (num_blocks - double_count) * 2

    self.blocks_to_swap = num_blocks
    self.double_blocks_to_swap = double_count
    self.single_blocks_to_swap = single_count

    message = f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
    print(message)

    # async block swap. 1 worker is enough
    self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def move_to_device_except_swap_blocks(self, device: torch.device):
|
||||
# assume model is on cpu
|
||||
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
||||
if self.blocks_to_swap:
|
||||
save_double_blocks = self.double_blocks
|
||||
save_single_blocks = self.single_blocks
|
||||
@@ -983,31 +988,55 @@ class Flux(nn.Module):
|
||||
self.double_blocks = save_double_blocks
|
||||
self.single_blocks = save_single_blocks
|
||||
|
||||
def get_block_unit(self, index: int):
    """Return the blocks belonging to swap unit ``index`` as a tuple.

    Unit indices cover ``self.double_blocks`` first (one block per unit) and
    then ``self.single_blocks`` in consecutive pairs.

    Returns:
        A 1-tuple with one double block, or a 2-tuple with two adjacent
        single blocks.
    """
    num_double = len(self.double_blocks)
    if index >= num_double:
        # unit k past the double blocks maps to single blocks 2k and 2k + 1
        first = (index - num_double) * 2
        return self.single_blocks[first], self.single_blocks[first + 1]
    return (self.double_blocks[index],)
|
||||
# def get_block_unit(self, index: int):
|
||||
# if index < len(self.double_blocks):
|
||||
# return (self.double_blocks[index],)
|
||||
# else:
|
||||
# index -= len(self.double_blocks)
|
||||
# index *= 2
|
||||
# return self.single_blocks[index], self.single_blocks[index + 1]
|
||||
|
||||
def get_unit_index(self, is_double: bool, index: int):
    """Map a block index to its swap-unit index.

    Double blocks map one-to-one. Single blocks are paired, so single block
    ``index`` belongs to unit ``len(self.double_blocks) + index // 2``.
    """
    return index if is_double else len(self.double_blocks) + index // 2
|
||||
# def get_unit_index(self, is_double: bool, index: int):
|
||||
# if is_double:
|
||||
# return index
|
||||
# else:
|
||||
# return len(self.double_blocks) + index // 2
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
# make: first n blocks are on cuda, and last n blocks are on cpu
|
||||
# # make: first n blocks are on cuda, and last n blocks are on cpu
|
||||
# if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
# # raise ValueError("Block swap is not enabled.")
|
||||
# return
|
||||
# for i in range(self.num_block_units - self.blocks_to_swap):
|
||||
# for b in self.get_block_unit(i):
|
||||
# b.to(self.device)
|
||||
# for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
|
||||
# for b in self.get_block_unit(i):
|
||||
# b.to("cpu")
|
||||
# clean_memory_on_device(self.device)
|
||||
|
||||
# all blocks are on device, but some weights are on cpu
|
||||
# make first n blocks weights on device, and last n blocks weights on cpu
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
# raise ValueError("Block swap is not enabled.")
|
||||
return
|
||||
for i in range(self.num_block_units - self.blocks_to_swap):
|
||||
for b in self.get_block_unit(i):
|
||||
b.to(self.device)
|
||||
for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
|
||||
for b in self.get_block_unit(i):
|
||||
b.to("cpu")
|
||||
|
||||
for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
|
||||
b.to(self.device)
|
||||
utils.weighs_to_device(b, self.device) # make sure weights are on device
|
||||
for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
|
||||
b.to(self.device) # move block to device first
|
||||
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
|
||||
torch.cuda.synchronize()
|
||||
clean_memory_on_device(self.device)
|
||||
|
||||
for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
|
||||
b.to(self.device)
|
||||
utils.weighs_to_device(b, self.device) # make sure weights are on device
|
||||
for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
|
||||
b.to(self.device) # move block to device first
|
||||
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
|
||||
torch.cuda.synchronize()
|
||||
clean_memory_on_device(self.device)
|
||||
|
||||
def forward(
|
||||
@@ -1044,27 +1073,22 @@ class Flux(nn.Module):
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
else:
|
||||
futures = {}
|
||||
# device = self.device
|
||||
|
||||
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
|
||||
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
|
||||
# print(f"Moving {bidx_to_cpu} to cpu.")
|
||||
for block in blocks_to_cpu:
|
||||
block.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# print(f"Moving {bidx_to_cuda} to cuda.")
|
||||
for block in blocks_to_cuda:
|
||||
block.to(self.device, non_blocking=True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
|
||||
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
||||
start_time = time.perf_counter()
|
||||
# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
|
||||
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
|
||||
return block_idx_to_cpu, block_idx_to_cuda
|
||||
|
||||
blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
|
||||
blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
|
||||
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
||||
return block_idx_to_cpu, block_idx_to_cuda # , event
|
||||
|
||||
block_to_cpu = blocks[block_idx_to_cpu]
|
||||
block_to_cuda = blocks[block_idx_to_cuda]
|
||||
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
|
||||
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
|
||||
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
||||
|
||||
def wait_for_blocks_move(block_idx, ftrs):
|
||||
if block_idx not in ftrs:
|
||||
@@ -1073,37 +1097,35 @@ class Flux(nn.Module):
|
||||
# start_time = time.perf_counter()
|
||||
ftr = ftrs.pop(block_idx)
|
||||
ftr.result()
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
||||
# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
||||
|
||||
double_futures = {}
|
||||
for block_idx, block in enumerate(self.double_blocks):
|
||||
# print(f"Double block {block_idx}")
|
||||
unit_idx = self.get_unit_index(is_double=True, index=block_idx)
|
||||
wait_for_blocks_move(unit_idx, futures)
|
||||
wait_for_blocks_move(block_idx, double_futures)
|
||||
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
if unit_idx < self.blocks_to_swap:
|
||||
block_idx_to_cpu = unit_idx
|
||||
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
|
||||
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
|
||||
futures[block_idx_to_cuda] = future
|
||||
if block_idx < self.double_blocks_to_swap:
|
||||
block_idx_to_cpu = block_idx
|
||||
block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
|
||||
future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
|
||||
double_futures[block_idx_to_cuda] = future
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
single_futures = {}
|
||||
for block_idx, block in enumerate(self.single_blocks):
|
||||
# print(f"Single block {block_idx}")
|
||||
unit_idx = self.get_unit_index(is_double=False, index=block_idx)
|
||||
if block_idx % 2 == 0:
|
||||
wait_for_blocks_move(unit_idx, futures)
|
||||
wait_for_blocks_move(block_idx, single_futures)
|
||||
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
|
||||
block_idx_to_cpu = unit_idx
|
||||
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
|
||||
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
|
||||
futures[block_idx_to_cuda] = future
|
||||
if block_idx < self.single_blocks_to_swap:
|
||||
block_idx_to_cpu = block_idx
|
||||
block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx
|
||||
future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
|
||||
single_futures[block_idx_to_cuda] = future
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
|
||||
222
library/utils.py
222
library/utils.py
@@ -6,6 +6,7 @@ import json
|
||||
import struct
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import transforms
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||
@@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
|
||||
# region PyTorch utils
|
||||
|
||||
# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
|
||||
# # cpu_tensor = module_to_cuda.weight.data
|
||||
# # cuda_tensor = module_to_cpu.weight.data
|
||||
# # assert cuda_tensor.device.type == "cuda"
|
||||
# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
|
||||
# # torch.cuda.current_stream().synchronize()
|
||||
# # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
|
||||
# # torch.cuda.current_stream().synchronize()
|
||||
# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
|
||||
# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
|
||||
# cuda_tensor_view = module_to_cpu.weight.data
|
||||
# cpu_tensor_view = module_to_cuda.weight.data
|
||||
# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
|
||||
# module_to_cuda.weight.data = cuda_tensor_view
|
||||
# module_to_cuda.weight.data.copy_(cpu_tensor_view)
|
||||
|
||||
|
||||
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Exchange the ``weight`` tensors of two same-class layers between devices.

    Submodules of both layers are paired positionally via ``modules()``. For
    every pair whose ``layer_to_cpu`` side has a non-None ``weight``, that
    weight is copied to CPU, and the tensor it previously occupied is then
    overwritten with the ``layer_to_cuda`` weight and handed to
    ``layer_to_cuda``. Only attributes named ``weight`` are touched.

    Assumes ``layer_to_cpu`` weights live on a CUDA device and
    ``layer_to_cuda`` weights on CPU; this is not checked here.

    Args:
        layer_to_cpu: layer whose weights are moved to CPU.
        layer_to_cuda: layer that takes over the vacated device tensors.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    # (module going to cpu, module going to cuda, current device tensor, current cpu tensor)
    swap_jobs = [
        (src_module, dst_module, src_module.weight.data, dst_module.weight.data)
        for src_module, dst_module in zip(layer_to_cpu.modules(), layer_to_cuda.modules())
        if hasattr(src_module, "weight") and src_module.weight is not None
    ]

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        # cuda to cpu
        for src_module, _, device_tensor, _ in swap_jobs:
            device_tensor.record_stream(copy_stream)
            src_module.weight.data = device_tensor.data.to("cpu", non_blocking=True)

        # the device tensors are overwritten next, so the copies above must have finished
        copy_stream.synchronize()

        # cpu to cuda, reusing the tensors that were just vacated
        for _, dst_module, device_tensor, _ in swap_jobs:
            device_tensor.copy_(dst_module.weight.data, non_blocking=True)
            dst_module.weight.data = device_tensor

    copy_stream.synchronize()
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
|
||||
|
||||
|
||||
def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Two-stream variant of the weight swap between two same-class layers.

    Copies to CPU are issued on one CUDA stream, with an event recorded after
    each copy. Copies onto the device are issued on a second stream, each one
    waiting on the matching event. Only attributes named ``weight`` are
    swapped.

    Assumes ``layer_to_cpu`` weights are on a CUDA device and
    ``layer_to_cuda`` weights on CPU; this is not checked here.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    # (module going to cpu, module going to cuda, current device tensor, current cpu tensor)
    swap_jobs = [
        (src_module, dst_module, src_module.weight.data, dst_module.weight.data)
        for src_module, dst_module in zip(layer_to_cpu.modules(), layer_to_cuda.modules())
        if hasattr(src_module, "weight") and src_module.weight is not None
    ]

    stream_to_cpu = torch.cuda.Stream()
    stream_to_cuda = torch.cuda.Stream()

    events = []
    offloaded_weights = []
    with torch.cuda.stream(stream_to_cpu):
        # cuda to offload: one event per job marks the end of its copy
        for _, _, device_tensor, _ in swap_jobs:
            offloaded_weights.append(device_tensor.to("cpu", non_blocking=True))
            offload_done = torch.cuda.Event()
            offload_done.record(stream=stream_to_cpu)
            events.append(offload_done)

    with torch.cuda.stream(stream_to_cuda):
        # cpu to cuda: wait for the offload of a tensor before overwriting it
        for (_, dst_module, device_tensor, _), offload_done in zip(swap_jobs, events):
            offload_done.synchronize()
            device_tensor.copy_(dst_module.weight.data, non_blocking=True)
            dst_module.weight.data = device_tensor

        # offload to cpu: hand the CPU copies to the modules that left the device
        for (src_module, _, _, _), cpu_copy in zip(swap_jobs, offloaded_weights):
            src_module.weight.data = cpu_copy

    stream_to_cuda.synchronize()

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
|
||||
|
||||
|
||||
def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Two-stream weight swap that also calls ``record_stream`` on the device tensors.

    NOTE(review): the name suggests this variant was found not to work
    reliably; confirm before calling it.

    Copies to CPU are issued on one stream with an event recorded per job.
    On a second stream each job waits for its event, overwrites the device
    tensor with the ``layer_to_cuda`` weight, and assigns the CPU copy to the
    ``layer_to_cpu`` module. Only attributes named ``weight`` are swapped.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    # (module going to cpu, module going to cuda, current device tensor, current cpu tensor)
    swap_jobs = [
        (src_module, dst_module, src_module.weight.data, dst_module.weight.data)
        for src_module, dst_module in zip(layer_to_cpu.modules(), layer_to_cuda.modules())
        if hasattr(src_module, "weight") and src_module.weight is not None
    ]

    stream_to_cpu = torch.cuda.Stream()
    stream_to_cuda = torch.cuda.Stream()

    # cuda to offload
    events = []
    offloaded_weights = []
    with torch.cuda.stream(stream_to_cpu):
        for _, _, device_tensor, _ in swap_jobs:
            device_tensor.record_stream(stream_to_cpu)
            offloaded_weights.append(device_tensor.to("cpu", non_blocking=True))

            # one event per job, consumed by the matching job in the loop below
            offload_done = torch.cuda.Event()
            offload_done.record(stream=stream_to_cpu)
            events.append(offload_done)

    # cpu to cuda
    with torch.cuda.stream(stream_to_cuda):
        for (src_module, dst_module, device_tensor, _), offload_done, cpu_copy in zip(
            swap_jobs, events, offloaded_weights
        ):
            offload_done.synchronize()
            device_tensor.record_stream(stream_to_cuda)
            device_tensor.copy_(dst_module.weight.data, non_blocking=True)
            dst_module.weight.data = device_tensor

            src_module.weight.data = cpu_copy

    stream_to_cuda.synchronize()

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
    # Disabled alternative kept from the original: make the current stream wait on
    # stream_to_cuda and record the current stream on each device tensor instead.
    # torch.cuda.current_stream().wait_stream(stream_to_cuda)
    # for job in swap_jobs:
    #     job[2].record_stream(torch.cuda.current_stream())  # record the ownership of the tensor
|
||||
|
||||
|
||||
def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Swap ``weight`` tensors between two same-class layers via a reusable CPU buffer.

    Each module pair shares one CPU staging tensor stored as the
    ``offloaded_weight`` attribute. The device weight is copied into that
    buffer, the device tensor is overwritten with the CPU weight, and the old
    CPU weight tensor becomes the staging buffer for the next swap. Only
    attributes named ``weight`` are handled.

    Args:
        layer_to_cpu: layer whose weights must currently be on a CUDA device.
        layer_to_cuda: layer whose weights must currently be on CPU.

    Raises:
        AssertionError: if the layers differ in class, or a weight is not on
            the expected device type.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
                # one of the modules must have the tensor to offload.
                # Tensor.pin_memory() returns a pinned copy instead of pinning in place,
                # so the returned tensor is the one that has to be stored.
                module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu").pin_memory()
            offloaded_weight = (
                module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
            )
            assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
            weight_swap_jobs.append(
                (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
            )

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to offload
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)

        # the device tensors are overwritten next, so the offload copies must have finished
        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

        # offload to cpu: the filled buffer becomes the weight, and the old CPU
        # weight tensor becomes the shared staging buffer for the next swap
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
            module_to_cpu.weight.data = offloaded_weight
            offloaded_weight = cpu_data_view
            module_to_cpu.offloaded_weight = offloaded_weight
            module_to_cuda.offloaded_weight = offloaded_weight

    stream.synchronize()

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
|
||||
|
||||
|
||||
def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Swap ``weight`` tensors between two same-class layers with plain ``.to()`` copies.

    Submodules are paired positionally via ``modules()``. For every pair whose
    ``layer_to_cpu`` side has a non-None ``weight``, that weight is moved to
    CPU and the paired ``layer_to_cuda`` weight is moved to ``"cuda"``. New
    tensors are allocated for both directions; the CUDA cache is emptied
    afterwards. Only attributes named ``weight`` are handled.

    The previous implementation also allocated a weight-sized CPU tensor per
    module (``__cached_cpu_weight``) that was never read, and discarded the
    result of ``pin_memory()`` on it. That dead allocation is removed.

    Args:
        layer_to_cpu: layer whose weights are moved to CPU.
        layer_to_cuda: layer whose weights are moved to the CUDA device.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    # snapshot both sides first so every pair is read before anything is reassigned
    weight_swap_jobs = [
        (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)
        for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules())
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None
    ]

    for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
        module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
        module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)

    torch.cuda.current_stream().synchronize()  # wait for the non-blocking copies issued above
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
|
||||
# weight_on_cuda = module_to_cpu.weight
|
||||
# weight_on_cpu = module_to_cuda.weight
|
||||
# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
|
||||
# event = torch.cuda.current_stream().record_event()
|
||||
# event.synchronize()
|
||||
# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
|
||||
# weight_on_cpu.data = cuda_to_cpu_data
|
||||
# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
|
||||
|
||||
# module_to_cpu.weight = weight_on_cpu
|
||||
# module_to_cuda.weight = weight_on_cuda
|
||||
|
||||
|
||||
def weighs_to_device(layer: nn.Module, device: torch.device):
    """Move the ``weight`` tensor of every submodule of ``layer`` to ``device``.

    Only attributes named ``weight`` are moved; other parameters and buffers
    (for example ``bias``) stay where they are. Submodules without a
    ``weight``, or whose ``weight`` is None, are skipped. The copy is
    requested with ``non_blocking=True``.

    Args:
        layer: module whose submodules (including itself) are visited.
        device: target device for the weights.
    """
    for submodule in layer.modules():
        weight = getattr(submodule, "weight", None)
        if weight is None:
            continue
        weight.data = weight.data.to(device, non_blocking=True)
|
||||
|
||||
|
||||
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
||||
"""
|
||||
@@ -313,6 +533,7 @@ class MemoryEfficientSafeOpen:
|
||||
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
||||
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
||||
|
||||
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||
) -> dict[str, torch.Tensor]:
|
||||
@@ -336,7 +557,6 @@ def load_safetensors(
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Image utils
|
||||
|
||||
Reference in New Issue
Block a user