mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
faster block swap
This commit is contained in:
103
flux_train.py
103
flux_train.py
@@ -17,12 +17,14 @@ import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import time
|
||||
from typing import List
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from library import utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -466,45 +468,28 @@ def train(args):
|
||||
|
||||
# memory efficient block swapping
|
||||
|
||||
def get_block_unit(dbl_blocks, sgl_blocks, index: int):
|
||||
if index < len(dbl_blocks):
|
||||
return (dbl_blocks[index],)
|
||||
else:
|
||||
index -= len(dbl_blocks)
|
||||
index *= 2
|
||||
return (sgl_blocks[index], sgl_blocks[index + 1])
|
||||
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
|
||||
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
||||
# start_time = time.perf_counter()
|
||||
# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
|
||||
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
|
||||
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
||||
return bidx_to_cpu, bidx_to_cuda # , event
|
||||
|
||||
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
|
||||
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
|
||||
# print(f"Backward: Move block {bidx_to_cpu} to CPU")
|
||||
for block in blocks_to_cpu:
|
||||
block = block.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
block_to_cpu = blocks[block_idx_to_cpu]
|
||||
block_to_cuda = blocks[block_idx_to_cuda]
|
||||
|
||||
# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
|
||||
for block in blocks_to_cuda:
|
||||
block = block.to(dvc, non_blocking=True)
|
||||
futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
|
||||
return bidx_to_cpu, bidx_to_cuda
|
||||
|
||||
blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
|
||||
blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
|
||||
|
||||
futures[block_idx_to_cuda] = thread_pool.submit(
|
||||
move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
|
||||
)
|
||||
|
||||
def wait_blocks_move(block_idx, futures):
|
||||
if block_idx not in futures:
|
||||
def wait_blocks_move(block_id, futures):
|
||||
if block_id not in futures:
|
||||
return
|
||||
# print(f"Backward: Wait for block {block_idx}")
|
||||
# print(f"Backward: Wait for block {block_id}")
|
||||
# start_time = time.perf_counter()
|
||||
future = futures.pop(block_idx)
|
||||
future.result()
|
||||
# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
||||
# torch.cuda.synchronize()
|
||||
future = futures.pop(block_id)
|
||||
_, bidx_to_cuda = future.result()
|
||||
assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
|
||||
# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
|
||||
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
|
||||
|
||||
if args.fused_backward_pass:
|
||||
@@ -513,11 +498,11 @@ def train(args):
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
blocks_to_swap = args.blocks_to_swap
|
||||
double_blocks_to_swap = args.blocks_to_swap // 2
|
||||
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
|
||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||
handled_unit_indices = set()
|
||||
handled_block_ids = set()
|
||||
|
||||
n = 1 # only asynchronous purpose, no need to increase this number
|
||||
# n = 2
|
||||
@@ -530,28 +515,37 @@ def train(args):
|
||||
if parameter.requires_grad:
|
||||
grad_hook = None
|
||||
|
||||
if blocks_to_swap:
|
||||
if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
|
||||
is_double = param_name.startswith("double_blocks")
|
||||
is_single = param_name.startswith("single_blocks")
|
||||
if is_double or is_single:
|
||||
if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
|
||||
block_idx = int(param_name.split(".")[1])
|
||||
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
|
||||
if unit_idx not in handled_unit_indices:
|
||||
block_id = (is_double, block_idx) # double or single, block index
|
||||
if block_id not in handled_block_ids:
|
||||
# swap following (already backpropagated) block
|
||||
handled_unit_indices.add(unit_idx)
|
||||
handled_block_ids.add(block_id)
|
||||
|
||||
# if n blocks were already backpropagated
|
||||
num_blocks_propagated = num_block_units - unit_idx - 1
|
||||
if is_double:
|
||||
num_blocks = num_double_blocks
|
||||
blocks_to_swap = double_blocks_to_swap
|
||||
else:
|
||||
num_blocks = num_single_blocks
|
||||
blocks_to_swap = single_blocks_to_swap
|
||||
|
||||
# -1 for 0-based index, -1 for current block is not fully backpropagated yet
|
||||
num_blocks_propagated = num_blocks - block_idx - 2
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
||||
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
|
||||
waiting = block_idx > 0 and block_idx <= blocks_to_swap
|
||||
|
||||
if swapping or waiting:
|
||||
block_idx_to_cpu = num_block_units - num_blocks_propagated
|
||||
block_idx_to_cpu = num_blocks - num_blocks_propagated
|
||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
||||
block_idx_to_wait = unit_idx - 1
|
||||
block_idx_to_wait = block_idx - 1
|
||||
|
||||
# create swap hook
|
||||
def create_swap_grad_hook(
|
||||
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
|
||||
is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
|
||||
):
|
||||
def __grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
@@ -559,24 +553,25 @@ def train(args):
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
|
||||
# print(f"Backward: {uidx}, {swpng}, {wtng}")
|
||||
# print(
|
||||
# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
|
||||
# )
|
||||
if swpng:
|
||||
submit_move_blocks(
|
||||
futures,
|
||||
thread_pool,
|
||||
bidx_to_cpu,
|
||||
bidx_to_cuda,
|
||||
flux.double_blocks,
|
||||
flux.single_blocks,
|
||||
accelerator.device,
|
||||
flux.double_blocks if is_dbl else flux.single_blocks,
|
||||
(is_dbl, bidx_to_cuda), # wait for this block
|
||||
)
|
||||
if wtng:
|
||||
wait_blocks_move(bidx_to_wait, futures)
|
||||
wait_blocks_move((is_dbl, bidx_to_wait), futures)
|
||||
|
||||
return __grad_hook
|
||||
|
||||
grad_hook = create_swap_grad_hook(
|
||||
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
|
||||
is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
|
||||
)
|
||||
|
||||
if grad_hook is None:
|
||||
|
||||
Reference in New Issue
Block a user