faster block swap

Kohya S
2024-11-05 21:22:42 +09:00
parent 5e32ee26a1
commit 81c0c965a2
3 changed files with 350 additions and 113 deletions


@@ -17,12 +17,14 @@ import math
 import os
 from multiprocessing import Value
 import time
-from typing import List
+from typing import List, Optional, Tuple, Union
 import toml
 from tqdm import tqdm
 import torch
+import torch.nn as nn
+from library import utils
 from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
@@ -466,45 +468,28 @@ def train(args):
     # memory efficient block swapping
-    def get_block_unit(dbl_blocks, sgl_blocks, index: int):
-        if index < len(dbl_blocks):
-            return (dbl_blocks[index],)
-        else:
-            index -= len(dbl_blocks)
-            index *= 2
-            return (sgl_blocks[index], sgl_blocks[index + 1])
-    def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
-        def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
-            # print(f"Backward: Move block {bidx_to_cpu} to CPU")
-            for block in blocks_to_cpu:
-                block = block.to("cpu", non_blocking=True)
-            torch.cuda.empty_cache()
-            # print(f"Backward: Move block {bidx_to_cuda} to CUDA")
-            for block in blocks_to_cuda:
-                block = block.to(dvc, non_blocking=True)
-            torch.cuda.synchronize()
-            # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
-            return bidx_to_cpu, bidx_to_cuda
-        blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
-        blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
-        futures[block_idx_to_cuda] = thread_pool.submit(
-            move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
-        )
-    def wait_blocks_move(block_idx, futures):
-        if block_idx not in futures:
+    def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
+        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
+            # start_time = time.perf_counter()
+            # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
+            utils.swap_weight_devices(block_to_cpu, block_to_cuda)
+            # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
+            return bidx_to_cpu, bidx_to_cuda  # , event
+        block_to_cpu = blocks[block_idx_to_cpu]
+        block_to_cuda = blocks[block_idx_to_cuda]
+        futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
+    def wait_blocks_move(block_id, futures):
+        if block_id not in futures:
             return
-        # print(f"Backward: Wait for block {block_idx}")
+        # print(f"Backward: Wait for block {block_id}")
         # start_time = time.perf_counter()
-        future = futures.pop(block_idx)
-        future.result()
-        # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
-        # torch.cuda.synchronize()
-        # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
+        future = futures.pop(block_id)
+        _, bidx_to_cuda = future.result()
+        assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
+        # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
     if args.fused_backward_pass:
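Note: the new move_blocks body replaces the old sequence (move one block unit to CPU with block.to(), call torch.cuda.empty_cache(), move the other unit to the GPU, then torch.cuda.synchronize()) with a single call to utils.swap_weight_devices, presumably added in library/utils.py (one of the other changed files, not shown in this excerpt). A minimal sketch of one way such a pairwise swap could work, assuming both blocks share the same architecture so their parameters pair up one-to-one; the function below is illustrative, not the library's actual implementation:

import torch
import torch.nn as nn

def swap_weight_devices_sketch(block_to_cpu: nn.Module, block_to_cuda: nn.Module) -> None:
    # block_to_cpu currently lives on the GPU, block_to_cuda on the CPU.
    # Reusing the existing GPU buffers avoids a free/alloc cycle per swap,
    # so no empty_cache() is needed and each weight is copied exactly once.
    for p_gpu, p_cpu in zip(block_to_cpu.parameters(), block_to_cuda.parameters()):
        gpu_buf = p_gpu.data                          # existing GPU storage, reused below
        cpu_copy = gpu_buf.detach().cpu()             # offload this block's weight to CPU
        gpu_buf.copy_(p_cpu.data, non_blocking=True)  # refill the same GPU buffer with the other block's weight
        p_gpu.data = cpu_copy
        p_cpu.data = gpu_buf
    torch.cuda.synchronize()                          # wait for the copies before training touches the weights

Because the GPU allocation is recycled rather than freed and reallocated, the old empty_cache() call becomes unnecessary, which is likely a large part of the speedup.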
@@ -513,11 +498,11 @@ def train(args):
         library.adafactor_fused.patch_adafactor_fused(optimizer)
-        blocks_to_swap = args.blocks_to_swap
+        double_blocks_to_swap = args.blocks_to_swap // 2
+        single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
         num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
         num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
-        num_block_units = num_double_blocks + num_single_blocks // 2
-        handled_unit_indices = set()
+        handled_block_ids = set()
         n = 1  # only asynchronous purpose, no need to increase this number
         # n = 2
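Note: the swap budget is now split per block type instead of being counted in mixed units. The removed get_block_unit helper treated one double block as equivalent to two single blocks, and the new split keeps that ratio. A quick check of the arithmetic with a hypothetical --blocks_to_swap value:

blocks_to_swap = 9  # hypothetical value of --blocks_to_swap, for illustration only
double_blocks_to_swap = blocks_to_swap // 2                           # 9 // 2 = 4 double blocks
single_blocks_to_swap = (blocks_to_swap - double_blocks_to_swap) * 2  # (9 - 4) * 2 = 10 single blocks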
@@ -530,28 +515,37 @@ def train(args):
                 if parameter.requires_grad:
                     grad_hook = None
-                    if blocks_to_swap:
+                    if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
                         is_double = param_name.startswith("double_blocks")
                         is_single = param_name.startswith("single_blocks")
-                        if is_double or is_single:
+                        if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
                             block_idx = int(param_name.split(".")[1])
-                            unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
-                            if unit_idx not in handled_unit_indices:
+                            block_id = (is_double, block_idx)  # double or single, block index
+                            if block_id not in handled_block_ids:
                                 # swap following (already backpropagated) block
-                                handled_unit_indices.add(unit_idx)
+                                handled_block_ids.add(block_id)
                                 # if n blocks were already backpropagated
-                                num_blocks_propagated = num_block_units - unit_idx - 1
+                                if is_double:
+                                    num_blocks = num_double_blocks
+                                    blocks_to_swap = double_blocks_to_swap
+                                else:
+                                    num_blocks = num_single_blocks
+                                    blocks_to_swap = single_blocks_to_swap
+                                # -1 for 0-based index, -1 for current block is not fully backpropagated yet
+                                num_blocks_propagated = num_blocks - block_idx - 2
                                 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
-                                waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
+                                waiting = block_idx > 0 and block_idx <= blocks_to_swap
                                 if swapping or waiting:
-                                    block_idx_to_cpu = num_block_units - num_blocks_propagated
+                                    block_idx_to_cpu = num_blocks - num_blocks_propagated
                                     block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
-                                    block_idx_to_wait = unit_idx - 1
+                                    block_idx_to_wait = block_idx - 1
                                     # create swap hook
                                     def create_swap_grad_hook(
-                                        bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
+                                        is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
                                     ):
                                         def __grad_hook(tensor: torch.Tensor):
                                             if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -559,24 +553,25 @@ def train(args):
                                             optimizer.step_param(tensor, param_group)
                                             tensor.grad = None
-                                            # print(f"Backward: {uidx}, {swpng}, {wtng}")
+                                            # print(
+                                            #     f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
+                                            # )
                                             if swpng:
                                                 submit_move_blocks(
                                                     futures,
                                                     thread_pool,
                                                     bidx_to_cpu,
                                                     bidx_to_cuda,
-                                                    flux.double_blocks,
-                                                    flux.single_blocks,
-                                                    accelerator.device,
+                                                    flux.double_blocks if is_dbl else flux.single_blocks,
+                                                    (is_dbl, bidx_to_cuda),  # wait for this block
                                                 )
                                             if wtng:
-                                                wait_blocks_move(bidx_to_wait, futures)
+                                                wait_blocks_move((is_dbl, bidx_to_wait), futures)
                                             return __grad_hook
                                     grad_hook = create_swap_grad_hook(
-                                        block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
+                                        is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
                                     )
                     if grad_hook is None:
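Note: create_swap_grad_hook is a factory so that each hook captures its own block indices and flags at creation time; a hook defined directly in the loop would close over the loop variables and see only their final values once backward() actually runs. The registration call is outside this hunk; in PyTorch 2.1+ the usual way to attach such a per-parameter hook for a fused backward pass is register_post_accumulate_grad_hook. A self-contained illustration of the factory pattern (names and values are illustrative, not from this repository):

import torch

def make_hook(block_idx: int):
    # the factory freezes block_idx for this particular hook
    def hook(param: torch.Tensor) -> None:
        print(f"grad ready for block {block_idx}, grad norm {param.grad.norm():.3f}")
    return hook

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for i, p in enumerate(params):
    # without the factory, every hook would read the shared loop variable i,
    # which would already be 2 by the time backward() runs
    p.register_post_accumulate_grad_hook(make_hook(i))

loss = sum((p * p).sum() for p in params)
loss.backward()  # each hook fires once its parameter's gradient has been accumulated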