New block swap for FLUX.1 fine-tuning

This commit is contained in:
Kohya S
2024-09-26 08:26:31 +09:00
parent 65fb69f808
commit 56a7bc171d
3 changed files with 294 additions and 166 deletions

View File

@@ -11,10 +11,12 @@
# - Per-block fused optimizer instances
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import math
import os
from multiprocessing import Value
import time
from typing import List
import toml
@@ -265,14 +267,30 @@ def train(args):
flux.requires_grad_(True)
is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap
# block swap
# backward compatibility
if args.blocks_to_swap is None:
blocks_to_swap = args.double_blocks_to_swap or 0
if args.single_blocks_to_swap is not None:
blocks_to_swap += args.single_blocks_to_swap // 2
if blocks_to_swap > 0:
logger.warning(
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
)
logger.info(
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
)
args.blocks_to_swap = blocks_to_swap
del blocks_to_swap
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(
f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}"
)
flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap)
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap)
if not cache_latents:
# load VAE here if not cached
@@ -443,82 +461,120 @@ def train(args):
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# memory efficient block swapping
def get_block_unit(dbl_blocks, sgl_blocks, index: int):
    """Return the tuple of model blocks that form swap unit *index*.

    Unit indices 0 .. len(dbl_blocks)-1 map to one double block each;
    every following unit maps to a consecutive pair of single blocks
    (so one unit is always roughly the same amount of memory).
    """
    num_double = len(dbl_blocks)
    if index < num_double:
        return (dbl_blocks[index],)
    pair_start = (index - num_double) * 2
    return (sgl_blocks[pair_start], sgl_blocks[pair_start + 1])
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
    """Asynchronously swap one block unit to CPU and another to *device*.

    Resolves both unit indices to their blocks via get_block_unit(),
    schedules the transfer on *thread_pool*, and stores the resulting
    Future in *futures* keyed by the index of the unit moving to CUDA,
    so wait_blocks_move() can later block on its completion.
    """

    def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
        # print(f"Backward: Move block {bidx_to_cpu} to CPU")
        for block in blocks_to_cpu:
            # nn.Module.to() moves parameters in place, so the rebinding of
            # the loop variable is harmless; non_blocking overlaps the copy.
            block = block.to("cpu", non_blocking=True)
        torch.cuda.empty_cache()

        # print(f"Backward: Move block {bidx_to_cuda} to CUDA")
        for block in blocks_to_cuda:
            block = block.to(dvc, non_blocking=True)
        # Ensure the non-blocking transfers above have actually finished
        # before this future is marked done (callers rely on the blocks
        # being resident once result() returns).
        torch.cuda.synchronize()

        # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
        return bidx_to_cpu, bidx_to_cuda

    blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
    blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)

    futures[block_idx_to_cuda] = thread_pool.submit(
        move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
    )
def wait_blocks_move(block_idx, futures):
    """Block until the pending swap for *block_idx*, if any, completes.

    Pops the Future registered by submit_move_blocks() and waits on it;
    a no-op when no move was scheduled for this index.
    """
    # print(f"Backward: Wait for block {block_idx}")
    # start_time = time.perf_counter()
    pending = futures.pop(block_idx, None)
    if pending is None:
        return
    pending.result()
    # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
    # torch.cuda.synchronize()
    # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
double_blocks_to_swap = args.double_blocks_to_swap
single_blocks_to_swap = args.single_blocks_to_swap
blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
handled_double_block_indices = set()
handled_single_block_indices = set()
num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set()
n = 1 # only asyncronous purpose, no need to increase this number
# n = 2
# n = max(1, os.cpu_count() // 2)
thread_pool = ThreadPoolExecutor(max_workers=n)
futures = {}
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
grad_hook = None
if double_blocks_to_swap:
if param_name.startswith("double_blocks"):
if blocks_to_swap:
is_double = param_name.startswith("double_blocks")
is_single = param_name.startswith("single_blocks")
if is_double or is_single:
block_idx = int(param_name.split(".")[1])
if (
block_idx not in handled_double_block_indices
and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1
and block_idx < num_double_blocks - 1
):
# swap next (already backpropagated) block
handled_double_block_indices.add(block_idx)
block_idx_cpu = block_idx + 1
block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu)
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
if unit_idx not in handled_unit_indices:
# swap following (already backpropagated) block
handled_unit_indices.add(unit_idx)
# create swap hook
def create_double_swap_grad_hook(bidx, bidx_cuda):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
# if n blocks were already backpropagated
num_blocks_propagated = num_block_units - unit_idx - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
if swapping or waiting:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
block_idx_to_wait = unit_idx - 1
# swap blocks if necessary
flux.double_blocks[bidx].to("cpu")
flux.double_blocks[bidx_cuda].to(accelerator.device)
# print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
# create swap hook
def create_swap_grad_hook(
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
return __grad_hook
# print(f"Backward: {uidx}, {swpng}, {wtng}")
if swpng:
submit_move_blocks(
futures,
thread_pool,
bidx_to_cpu,
bidx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
)
if wtng:
wait_blocks_move(bidx_to_wait, futures)
grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda)
if single_blocks_to_swap:
if param_name.startswith("single_blocks"):
block_idx = int(param_name.split(".")[1])
if (
block_idx not in handled_single_block_indices
and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1
and block_idx < num_single_blocks - 1
):
handled_single_block_indices.add(block_idx)
block_idx_cpu = block_idx + 1
block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu)
# print(param_name, block_idx_cpu, block_idx_cuda)
return __grad_hook
# create swap hook
def create_single_swap_grad_hook(bidx, bidx_cuda):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
# swap blocks if necessary
flux.single_blocks[bidx].to("cpu")
flux.single_blocks[bidx_cuda].to(accelerator.device)
# print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
return __grad_hook
grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda)
grad_hook = create_swap_grad_hook(
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
)
if grad_hook is None:
@@ -547,10 +603,15 @@ def train(args):
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
double_blocks_to_swap = args.double_blocks_to_swap
single_blocks_to_swap = args.single_blocks_to_swap
blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
n = 1 # only asyncronous purpose, no need to increase this number
# n = max(1, os.cpu_count() // 2)
thread_pool = ThreadPoolExecutor(max_workers=n)
futures = {}
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
@@ -571,18 +632,30 @@ def train(args):
optimizers[i].zero_grad(set_to_none=True)
# swap blocks if necessary
if btype == "double" and double_blocks_to_swap:
if bidx >= num_double_blocks - double_blocks_to_swap:
bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx)
flux.double_blocks[bidx].to("cpu")
flux.double_blocks[bidx_cuda].to(accelerator.device)
# print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
elif btype == "single" and single_blocks_to_swap:
if bidx >= num_single_blocks - single_blocks_to_swap:
bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx)
flux.single_blocks[bidx].to("cpu")
flux.single_blocks[bidx_cuda].to(accelerator.device)
# print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
num_blocks_propagated = num_block_units - unit_idx
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
if swapping:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
# print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
submit_move_blocks(
futures,
thread_pool,
block_idx_to_cpu,
block_idx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
)
if waiting:
block_idx_to_wait = unit_idx - 1
wait_blocks_move(block_idx_to_wait, futures)
return optimizer_hook
@@ -881,24 +954,26 @@ def setup_parser() -> argparse.ArgumentParser:
help="skip latents validity check / latentsの正当性チェックをスキップする",
)
parser.add_argument(
"--double_blocks_to_swap",
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes."
"Sets the number of blocks (~640MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップする'変換ブロック'約640MBの数を設定します。"
" / 順伝播および逆伝播中にスワップするブロック約640MBの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
)
parser.add_argument(
"--double_blocks_to_swap",
type=int,
default=None,
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
)
parser.add_argument(
"--single_blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップする'変換ブロック'約320MBの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
)
parser.add_argument(
"--cpu_offload_checkpointing",