mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
faster block swap
This commit is contained in:
103
flux_train.py
103
flux_train.py
@@ -17,12 +17,14 @@ import math
|
|||||||
import os
|
import os
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List, Optional, Tuple, Union
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from library import utils
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -466,45 +468,28 @@ def train(args):
|
|||||||
|
|
||||||
# memory efficient block swapping
|
# memory efficient block swapping
|
||||||
|
|
||||||
def get_block_unit(dbl_blocks, sgl_blocks, index: int):
|
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
|
||||||
if index < len(dbl_blocks):
|
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
||||||
return (dbl_blocks[index],)
|
# start_time = time.perf_counter()
|
||||||
else:
|
# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
|
||||||
index -= len(dbl_blocks)
|
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||||
index *= 2
|
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
||||||
return (sgl_blocks[index], sgl_blocks[index + 1])
|
return bidx_to_cpu, bidx_to_cuda # , event
|
||||||
|
|
||||||
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
|
block_to_cpu = blocks[block_idx_to_cpu]
|
||||||
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
|
block_to_cuda = blocks[block_idx_to_cuda]
|
||||||
# print(f"Backward: Move block {bidx_to_cpu} to CPU")
|
|
||||||
for block in blocks_to_cpu:
|
|
||||||
block = block.to("cpu", non_blocking=True)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
|
futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
||||||
for block in blocks_to_cuda:
|
|
||||||
block = block.to(dvc, non_blocking=True)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
def wait_blocks_move(block_id, futures):
|
||||||
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
|
if block_id not in futures:
|
||||||
return bidx_to_cpu, bidx_to_cuda
|
|
||||||
|
|
||||||
blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
|
|
||||||
blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
|
|
||||||
|
|
||||||
futures[block_idx_to_cuda] = thread_pool.submit(
|
|
||||||
move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
|
|
||||||
)
|
|
||||||
|
|
||||||
def wait_blocks_move(block_idx, futures):
|
|
||||||
if block_idx not in futures:
|
|
||||||
return
|
return
|
||||||
# print(f"Backward: Wait for block {block_idx}")
|
# print(f"Backward: Wait for block {block_id}")
|
||||||
# start_time = time.perf_counter()
|
# start_time = time.perf_counter()
|
||||||
future = futures.pop(block_idx)
|
future = futures.pop(block_id)
|
||||||
future.result()
|
_, bidx_to_cuda = future.result()
|
||||||
# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
|
||||||
# torch.cuda.synchronize()
|
# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
|
||||||
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
|
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
|
||||||
|
|
||||||
if args.fused_backward_pass:
|
if args.fused_backward_pass:
|
||||||
@@ -513,11 +498,11 @@ def train(args):
|
|||||||
|
|
||||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||||
|
|
||||||
blocks_to_swap = args.blocks_to_swap
|
double_blocks_to_swap = args.blocks_to_swap // 2
|
||||||
|
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
|
||||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
handled_block_ids = set()
|
||||||
handled_unit_indices = set()
|
|
||||||
|
|
||||||
n = 1 # only asynchronous purpose, no need to increase this number
|
n = 1 # only asynchronous purpose, no need to increase this number
|
||||||
# n = 2
|
# n = 2
|
||||||
@@ -530,28 +515,37 @@ def train(args):
|
|||||||
if parameter.requires_grad:
|
if parameter.requires_grad:
|
||||||
grad_hook = None
|
grad_hook = None
|
||||||
|
|
||||||
if blocks_to_swap:
|
if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
|
||||||
is_double = param_name.startswith("double_blocks")
|
is_double = param_name.startswith("double_blocks")
|
||||||
is_single = param_name.startswith("single_blocks")
|
is_single = param_name.startswith("single_blocks")
|
||||||
if is_double or is_single:
|
if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
|
||||||
block_idx = int(param_name.split(".")[1])
|
block_idx = int(param_name.split(".")[1])
|
||||||
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
|
block_id = (is_double, block_idx) # double or single, block index
|
||||||
if unit_idx not in handled_unit_indices:
|
if block_id not in handled_block_ids:
|
||||||
# swap following (already backpropagated) block
|
# swap following (already backpropagated) block
|
||||||
handled_unit_indices.add(unit_idx)
|
handled_block_ids.add(block_id)
|
||||||
|
|
||||||
# if n blocks were already backpropagated
|
# if n blocks were already backpropagated
|
||||||
num_blocks_propagated = num_block_units - unit_idx - 1
|
if is_double:
|
||||||
|
num_blocks = num_double_blocks
|
||||||
|
blocks_to_swap = double_blocks_to_swap
|
||||||
|
else:
|
||||||
|
num_blocks = num_single_blocks
|
||||||
|
blocks_to_swap = single_blocks_to_swap
|
||||||
|
|
||||||
|
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
|
||||||
|
num_blocks_propagated = num_blocks - block_idx - 2
|
||||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
||||||
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
|
waiting = block_idx > 0 and block_idx <= blocks_to_swap
|
||||||
|
|
||||||
if swapping or waiting:
|
if swapping or waiting:
|
||||||
block_idx_to_cpu = num_block_units - num_blocks_propagated
|
block_idx_to_cpu = num_blocks - num_blocks_propagated
|
||||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
||||||
block_idx_to_wait = unit_idx - 1
|
block_idx_to_wait = block_idx - 1
|
||||||
|
|
||||||
# create swap hook
|
# create swap hook
|
||||||
def create_swap_grad_hook(
|
def create_swap_grad_hook(
|
||||||
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
|
is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
|
||||||
):
|
):
|
||||||
def __grad_hook(tensor: torch.Tensor):
|
def __grad_hook(tensor: torch.Tensor):
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
@@ -559,24 +553,25 @@ def train(args):
|
|||||||
optimizer.step_param(tensor, param_group)
|
optimizer.step_param(tensor, param_group)
|
||||||
tensor.grad = None
|
tensor.grad = None
|
||||||
|
|
||||||
# print(f"Backward: {uidx}, {swpng}, {wtng}")
|
# print(
|
||||||
|
# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
|
||||||
|
# )
|
||||||
if swpng:
|
if swpng:
|
||||||
submit_move_blocks(
|
submit_move_blocks(
|
||||||
futures,
|
futures,
|
||||||
thread_pool,
|
thread_pool,
|
||||||
bidx_to_cpu,
|
bidx_to_cpu,
|
||||||
bidx_to_cuda,
|
bidx_to_cuda,
|
||||||
flux.double_blocks,
|
flux.double_blocks if is_dbl else flux.single_blocks,
|
||||||
flux.single_blocks,
|
(is_dbl, bidx_to_cuda), # wait for this block
|
||||||
accelerator.device,
|
|
||||||
)
|
)
|
||||||
if wtng:
|
if wtng:
|
||||||
wait_blocks_move(bidx_to_wait, futures)
|
wait_blocks_move((is_dbl, bidx_to_wait), futures)
|
||||||
|
|
||||||
return __grad_hook
|
return __grad_hook
|
||||||
|
|
||||||
grad_hook = create_swap_grad_hook(
|
grad_hook = create_swap_grad_hook(
|
||||||
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
|
is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
|
||||||
)
|
)
|
||||||
|
|
||||||
if grad_hook is None:
|
if grad_hook is None:
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ from dataclasses import dataclass
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from library import utils
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -923,7 +924,8 @@ class Flux(nn.Module):
|
|||||||
self.blocks_to_swap = None
|
self.blocks_to_swap = None
|
||||||
|
|
||||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
||||||
self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
|
self.num_double_blocks = len(self.double_blocks)
|
||||||
|
self.num_single_blocks = len(self.single_blocks)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
@@ -963,14 +965,17 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
def enable_block_swap(self, num_blocks: int):
    """Turn on block swapping and split the swap budget between block types.

    Stores the requested total in ``self.blocks_to_swap``, derives the number
    of double blocks (``num_blocks // 2``) and single blocks (twice the
    remainder of the budget) to swap, reports the split on stdout, and creates
    the thread pool used to run the swaps.

    Args:
        num_blocks: Total swap budget requested by the caller.
    """
    double_count = num_blocks // 2
    single_count = (num_blocks - double_count) * 2

    self.blocks_to_swap = num_blocks
    self.double_blocks_to_swap = double_count
    self.single_blocks_to_swap = single_count
    print(
        f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
    )

    # The pool exists only to make the swap asynchronous; one worker is enough.
    self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
def move_to_device_except_swap_blocks(self, device: torch.device):
|
def move_to_device_except_swap_blocks(self, device: torch.device):
|
||||||
# assume model is on cpu
|
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
||||||
if self.blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
save_double_blocks = self.double_blocks
|
save_double_blocks = self.double_blocks
|
||||||
save_single_blocks = self.single_blocks
|
save_single_blocks = self.single_blocks
|
||||||
@@ -983,31 +988,55 @@ class Flux(nn.Module):
|
|||||||
self.double_blocks = save_double_blocks
|
self.double_blocks = save_double_blocks
|
||||||
self.single_blocks = save_single_blocks
|
self.single_blocks = save_single_blocks
|
||||||
|
|
||||||
def get_block_unit(self, index: int):
|
# def get_block_unit(self, index: int):
|
||||||
if index < len(self.double_blocks):
|
# if index < len(self.double_blocks):
|
||||||
return (self.double_blocks[index],)
|
# return (self.double_blocks[index],)
|
||||||
else:
|
# else:
|
||||||
index -= len(self.double_blocks)
|
# index -= len(self.double_blocks)
|
||||||
index *= 2
|
# index *= 2
|
||||||
return self.single_blocks[index], self.single_blocks[index + 1]
|
# return self.single_blocks[index], self.single_blocks[index + 1]
|
||||||
|
|
||||||
def get_unit_index(self, is_double: bool, index: int):
|
# def get_unit_index(self, is_double: bool, index: int):
|
||||||
if is_double:
|
# if is_double:
|
||||||
return index
|
# return index
|
||||||
else:
|
# else:
|
||||||
return len(self.double_blocks) + index // 2
|
# return len(self.double_blocks) + index // 2
|
||||||
|
|
||||||
def prepare_block_swap_before_forward(self):
    """Put block weights in their starting positions for a swapped forward pass.

    Every block module is moved to ``self.device``. The weights of the leading
    blocks are then placed on the device, while the weights of the trailing
    ``double_blocks_to_swap`` / ``single_blocks_to_swap`` blocks are sent to
    the CPU. Does nothing when block swap is disabled (``blocks_to_swap`` is
    ``None`` or ``0``).
    """
    if not self.blocks_to_swap:
        # Block swap is not enabled; nothing to arrange.
        return

    # (blocks, index of the first block whose weights start on the CPU)
    block_groups = (
        (self.double_blocks, self.num_double_blocks - self.double_blocks_to_swap),
        (self.single_blocks, self.num_single_blocks - self.single_blocks_to_swap),
    )

    for blocks, first_offloaded in block_groups:
        # Leading blocks: module and weights both live on the device.
        for resident_block in blocks[0:first_offloaded]:
            resident_block.to(self.device)
            utils.weighs_to_device(resident_block, self.device)

        # Trailing blocks: move the module to the device first, then send only
        # its weights back to the CPU.
        for offloaded_block in blocks[first_offloaded:]:
            offloaded_block.to(self.device)
            utils.weighs_to_device(offloaded_block, "cpu")

        # The weight moves are issued non-blocking; wait for them before
        # releasing cached device memory.
        torch.cuda.synchronize()
        clean_memory_on_device(self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1044,27 +1073,22 @@ class Flux(nn.Module):
|
|||||||
for block in self.single_blocks:
|
for block in self.single_blocks:
|
||||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
else:
|
else:
|
||||||
futures = {}
|
# device = self.device
|
||||||
|
|
||||||
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
|
def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
|
||||||
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
|
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
||||||
# print(f"Moving {bidx_to_cpu} to cpu.")
|
start_time = time.perf_counter()
|
||||||
for block in blocks_to_cpu:
|
# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
|
||||||
block.to("cpu", non_blocking=True)
|
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# print(f"Moving {bidx_to_cuda} to cuda.")
|
|
||||||
for block in blocks_to_cuda:
|
|
||||||
block.to(self.device, non_blocking=True)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
|
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
|
||||||
return block_idx_to_cpu, block_idx_to_cuda
|
|
||||||
|
|
||||||
blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
|
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
||||||
blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
|
return block_idx_to_cpu, block_idx_to_cuda # , event
|
||||||
|
|
||||||
|
block_to_cpu = blocks[block_idx_to_cpu]
|
||||||
|
block_to_cuda = blocks[block_idx_to_cuda]
|
||||||
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
|
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
|
||||||
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
|
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
||||||
|
|
||||||
def wait_for_blocks_move(block_idx, ftrs):
|
def wait_for_blocks_move(block_idx, ftrs):
|
||||||
if block_idx not in ftrs:
|
if block_idx not in ftrs:
|
||||||
@@ -1073,37 +1097,35 @@ class Flux(nn.Module):
|
|||||||
# start_time = time.perf_counter()
|
# start_time = time.perf_counter()
|
||||||
ftr = ftrs.pop(block_idx)
|
ftr = ftrs.pop(block_idx)
|
||||||
ftr.result()
|
ftr.result()
|
||||||
# torch.cuda.synchronize()
|
# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
||||||
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
|
double_futures = {}
|
||||||
for block_idx, block in enumerate(self.double_blocks):
|
for block_idx, block in enumerate(self.double_blocks):
|
||||||
# print(f"Double block {block_idx}")
|
# print(f"Double block {block_idx}")
|
||||||
unit_idx = self.get_unit_index(is_double=True, index=block_idx)
|
wait_for_blocks_move(block_idx, double_futures)
|
||||||
wait_for_blocks_move(unit_idx, futures)
|
|
||||||
|
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
|
|
||||||
if unit_idx < self.blocks_to_swap:
|
if block_idx < self.double_blocks_to_swap:
|
||||||
block_idx_to_cpu = unit_idx
|
block_idx_to_cpu = block_idx
|
||||||
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
|
block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
|
||||||
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
|
future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
|
||||||
futures[block_idx_to_cuda] = future
|
double_futures[block_idx_to_cuda] = future
|
||||||
|
|
||||||
img = torch.cat((txt, img), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
|
# Run the single blocks, swapping weights as we go: once one of the first
# `single_blocks_to_swap` blocks has executed, its weights are exchanged with
# those of a block from the CPU-resident tail so that the tail block is ready
# by the time the loop reaches it.
single_futures = {}
for block_idx, block in enumerate(self.single_blocks):
    # print(f"Single block {block_idx}")
    # Block until the pending swap that brings this block's weights in (if any) is done.
    wait_for_blocks_move(block_idx, single_futures)

    img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)

    if block_idx < self.single_blocks_to_swap:
        block_idx_to_cpu = block_idx
        # The CPU-resident tail holds `single_blocks_to_swap` blocks, so the
        # partner index must be offset by that count. Using `blocks_to_swap`
        # here is only equivalent when the total is even; with an odd total the
        # first tail block was never uploaded and the last index overran the list.
        block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx
        future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
        # Keyed by the block being uploaded so the wait above finds it.
        single_futures[block_idx_to_cuda] = future
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
|
|||||||
222
library/utils.py
222
library/utils.py
@@ -6,6 +6,7 @@ import json
|
|||||||
import struct
|
import struct
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from diffusers import EulerAncestralDiscreteScheduler
|
from diffusers import EulerAncestralDiscreteScheduler
|
||||||
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||||
@@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False):
|
|||||||
|
|
||||||
# region PyTorch utils
|
# region PyTorch utils
|
||||||
|
|
||||||
|
# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||||
|
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||||
|
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||||
|
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||||
|
# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
|
||||||
|
# # cpu_tensor = module_to_cuda.weight.data
|
||||||
|
# # cuda_tensor = module_to_cpu.weight.data
|
||||||
|
# # assert cuda_tensor.device.type == "cuda"
|
||||||
|
# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
|
||||||
|
# # torch.cuda.current_stream().synchronize()
|
||||||
|
# # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
|
||||||
|
# # torch.cuda.current_stream().synchronize()
|
||||||
|
# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
|
||||||
|
# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
|
||||||
|
# cuda_tensor_view = module_to_cpu.weight.data
|
||||||
|
# cpu_tensor_view = module_to_cuda.weight.data
|
||||||
|
# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
|
||||||
|
# module_to_cuda.weight.data = cuda_tensor_view
|
||||||
|
# module_to_cuda.weight.data.copy_(cpu_tensor_view)
|
||||||
|
|
||||||
|
|
||||||
|
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Exchange the weight storage of two layers of the same class.

    For every pair of corresponding submodules that has a non-``None``
    ``weight``, the weight of ``layer_to_cpu`` is copied to the CPU, and the
    storage it occupied is overwritten with the weight of ``layer_to_cuda``,
    which then adopts that storage. Only ``weight`` attributes are touched.

    Args:
        layer_to_cpu: Layer whose weights are expected to be on CUDA.
        layer_to_cuda: Layer of the same class whose weights are expected to
            be on the CPU.

    Raises:
        AssertionError: If the two layers are not instances of the same class.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    # Capture both storages up front, before any ``.data`` is rebound.
    swap_jobs = [
        (src_module, dst_module, src_module.weight.data, dst_module.weight.data)
        for src_module, dst_module in zip(layer_to_cpu.modules(), layer_to_cuda.modules())
        if hasattr(src_module, "weight") and src_module.weight is not None
    ]

    # Original author's note: this sync prevents the illegal loss value.
    torch.cuda.current_stream().synchronize()

    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        # cuda to cpu
        for src_module, _, device_storage, _ in swap_jobs:
            device_storage.record_stream(copy_stream)
            src_module.weight.data = device_storage.data.to("cpu", non_blocking=True)

        # The device storage is about to be overwritten; wait for the downloads.
        copy_stream.synchronize()

        # cpu to cuda, reusing the storage that was just vacated
        for _, dst_module, device_storage, _ in swap_jobs:
            device_storage.copy_(dst_module.weight.data, non_blocking=True)
            dst_module.weight.data = device_storage

    copy_stream.synchronize()
    # Original author's note: this sync prevents the illegal loss value.
    torch.cuda.current_stream().synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Two-stream variant of the weight swap between two same-class layers.

    Downloads are issued on one CUDA stream and uploads on another; each
    upload waits on an event recorded after the matching download was queued.
    Nothing in the visible code calls this variant.

    Args:
        layer_to_cpu: Layer whose weights are expected to be on CUDA.
        layer_to_cuda: Layer of the same class whose weights are expected to
            be on the CPU.

    Raises:
        AssertionError: If the two layers are not instances of the same class.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    swap_jobs = [
        (src_module, dst_module, src_module.weight.data, dst_module.weight.data)
        for src_module, dst_module in zip(layer_to_cpu.modules(), layer_to_cuda.modules())
        if hasattr(src_module, "weight") and src_module.weight is not None
    ]

    download_stream = torch.cuda.Stream()
    upload_stream = torch.cuda.Stream()

    download_events = []
    with torch.cuda.stream(download_stream):
        # cuda to offload
        host_copies = []
        for _, _, device_storage, _ in swap_jobs:
            host_copies.append(device_storage.to("cpu", non_blocking=True))
            queued = torch.cuda.Event()
            queued.record(stream=download_stream)
            download_events.append(queued)

    with torch.cuda.stream(upload_stream):
        # cpu to cuda: overwrite each device storage once its download is done
        for (_, dst_module, device_storage, _), queued in zip(swap_jobs, download_events):
            queued.synchronize()
            device_storage.copy_(dst_module.weight.data, non_blocking=True)
            dst_module.weight.data = device_storage

        # offload to cpu: hand the downloaded copies to the offloaded layer
        for (src_module, _, _, _), host_copy in zip(swap_jobs, host_copies):
            src_module.weight.data = host_copy

    upload_stream.synchronize()

    # Original author's note: this sync prevents the illegal loss value.
    torch.cuda.current_stream().synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Two-stream weight swap that finishes each module inside the upload loop.

    Like the other two-stream variant, but the device storage is registered
    with both streams via ``record_stream`` and the downloaded copy is assigned
    to the offloaded module in the same loop iteration as the upload.
    NOTE(review): the ``_failed`` suffix suggests the author found this variant
    unreliable - confirm before using. Nothing in the visible code calls it.

    Args:
        layer_to_cpu: Layer whose weights are expected to be on CUDA.
        layer_to_cuda: Layer of the same class whose weights are expected to
            be on the CPU.

    Raises:
        AssertionError: If the two layers are not instances of the same class.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    swap_jobs = [
        (src_module, dst_module, src_module.weight.data, dst_module.weight.data)
        for src_module, dst_module in zip(layer_to_cpu.modules(), layer_to_cuda.modules())
        if hasattr(src_module, "weight") and src_module.weight is not None
    ]

    download_stream = torch.cuda.Stream()
    upload_stream = torch.cuda.Stream()

    # cuda to offload
    download_events = []
    with torch.cuda.stream(download_stream):
        host_copies = []
        for _, _, device_storage, _ in swap_jobs:
            device_storage.record_stream(download_stream)
            host_copies.append(device_storage.to("cpu", non_blocking=True))

            queued = torch.cuda.Event()
            queued.record(stream=download_stream)
            download_events.append(queued)

    # cpu to cuda
    with torch.cuda.stream(upload_stream):
        for (src_module, dst_module, device_storage, _), queued, host_copy in zip(
            swap_jobs, download_events, host_copies
        ):
            queued.synchronize()
            device_storage.record_stream(upload_stream)
            device_storage.copy_(dst_module.weight.data, non_blocking=True)
            dst_module.weight.data = device_storage

            src_module.weight.data = host_copy

    upload_stream.synchronize()

    # Original author's note: this sync prevents the illegal loss value.
    torch.cuda.current_stream().synchronize()
    # An alternative left disabled by the author made the current stream wait on
    # the upload stream and recorded each device storage on the current stream:
    # torch.cuda.current_stream().wait_stream(upload_stream)
    # for job in swap_jobs:
    #     job[2].record_stream(torch.cuda.current_stream())  # record the ownership of the tensor
|
||||||
|
|
||||||
|
|
||||||
|
def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Swap weights between two same-class layers through a reusable CPU buffer.

    Each module pair shares an ``offloaded_weight`` CPU tensor stored as an
    attribute on the modules. The CUDA weight is copied into that buffer, the
    CUDA storage is overwritten with the other layer's CPU weight, and the CPU
    storage that was just read becomes the pair's next ``offloaded_weight``.
    Nothing in the visible code calls this variant.

    Args:
        layer_to_cpu: Layer whose weights must currently be on CUDA.
        layer_to_cuda: Layer of the same class whose weights must currently be
            on the CPU.

    Raises:
        AssertionError: If the layers differ in class, or a weight pair is not
            on (cuda, cpu) respectively.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
                # one of the modules must have the tensor to offload
                # Tensor.pin_memory() returns a pinned copy rather than pinning in
                # place, so the result has to be the tensor that is stored.
                module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu").pin_memory()
            offloaded_weight = (
                module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
            )
            assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
            weight_swap_jobs.append(
                (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
            )

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to offload
        for module_to_cpu, _, cuda_data_view, _, offloaded_weight in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)

        # The CUDA storage is overwritten next; wait for the downloads first.
        stream.synchronize()

        # cpu to cuda
        for _, module_to_cuda, cuda_data_view, _, _ in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

        # offload to cpu
        for module_to_cpu, module_to_cuda, _, cpu_data_view, offloaded_weight in weight_swap_jobs:
            module_to_cpu.weight.data = offloaded_weight
            # The CPU storage that was just uploaded becomes the buffer for the
            # next swap of this pair, whichever direction it goes.
            module_to_cpu.offloaded_weight = cpu_data_view
            module_to_cuda.offloaded_weight = cpu_data_view

    stream.synchronize()

    # Original author's note: this sync prevents the illegal loss value.
    torch.cuda.current_stream().synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """Swap weights between two same-class layers with plain ``Tensor.to`` copies.

    For each pair of corresponding submodules with a non-``None`` ``weight``,
    the weight of ``layer_to_cpu`` is replaced by a CPU copy and the weight of
    ``layer_to_cuda`` by a CUDA copy. The current stream is then synchronized
    and the CUDA caching allocator is emptied. Nothing in the visible code
    calls this variant.

    Args:
        layer_to_cpu: Layer whose weights are expected to be on CUDA.
        layer_to_cuda: Layer of the same class whose weights are expected to
            be on the CPU.

    Raises:
        AssertionError: If the two layers are not instances of the same class.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            # NOTE(review): "__cached_cuda_weight" is never assigned in this
            # function, so only the first check can ever be true - confirm intent.
            if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
                # one of the modules must have the tensor to cache
                # Tensor.pin_memory() returns a pinned copy rather than pinning in
                # place, so the result has to be the tensor that is stored.
                # NOTE(review): this function never reads the cached buffer.
                module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu").pin_memory()

            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
        module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
        module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)

    # Wait for the non-blocking copies before releasing the old device storage.
    torch.cuda.current_stream().synchronize()
    torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||||
|
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||||
|
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||||
|
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||||
|
# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
|
||||||
|
# weight_on_cuda = module_to_cpu.weight
|
||||||
|
# weight_on_cpu = module_to_cuda.weight
|
||||||
|
# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
|
||||||
|
# event = torch.cuda.current_stream().record_event()
|
||||||
|
# event.synchronize()
|
||||||
|
# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
|
||||||
|
# weight_on_cpu.data = cuda_to_cpu_data
|
||||||
|
# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
|
||||||
|
|
||||||
|
# module_to_cpu.weight = weight_on_cpu
|
||||||
|
# module_to_cuda.weight = weight_on_cuda
|
||||||
|
|
||||||
|
|
||||||
|
def weighs_to_device(layer: nn.Module, device: torch.device):
    """Move only the ``weight`` tensors of a layer to ``device``.

    Walks ``layer`` and all of its submodules; every module with a
    non-``None`` ``weight`` gets ``weight.data`` rebound to a copy on
    ``device``, requested as non-blocking. Other parameters and buffers are
    left where they are.

    Args:
        layer: Module whose weights (including those of submodules) are moved.
        device: Target device, passed straight to ``Tensor.to``.
    """
    for submodule in layer.modules():
        weight = getattr(submodule, "weight", None)
        if weight is None:
            continue
        weight.data = weight.data.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
||||||
"""
|
"""
|
||||||
@@ -313,6 +533,7 @@ class MemoryEfficientSafeOpen:
|
|||||||
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
||||||
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
||||||
|
|
||||||
|
|
||||||
def load_safetensors(
|
def load_safetensors(
|
||||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||||
) -> dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
@@ -336,7 +557,6 @@ def load_safetensors(
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Image utils
|
# region Image utils
|
||||||
|
|||||||
Reference in New Issue
Block a user