mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
new block swap for FLUX.1 fine tuning
This commit is contained in:
241
flux_train.py
241
flux_train.py
@@ -11,10 +11,12 @@
|
||||
# - Per-block fused optimizer instances
|
||||
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import time
|
||||
from typing import List
|
||||
import toml
|
||||
|
||||
@@ -265,14 +267,30 @@ def train(args):
|
||||
|
||||
flux.requires_grad_(True)
|
||||
|
||||
is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap
|
||||
# block swap
|
||||
|
||||
# backward compatibility
|
||||
if args.blocks_to_swap is None:
|
||||
blocks_to_swap = args.double_blocks_to_swap or 0
|
||||
if args.single_blocks_to_swap is not None:
|
||||
blocks_to_swap += args.single_blocks_to_swap // 2
|
||||
if blocks_to_swap > 0:
|
||||
logger.warning(
|
||||
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
|
||||
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
|
||||
)
|
||||
logger.info(
|
||||
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
|
||||
)
|
||||
args.blocks_to_swap = blocks_to_swap
|
||||
del blocks_to_swap
|
||||
|
||||
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if is_swapping_blocks:
|
||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
# This idea is based on 2kpr's great work. Thank you!
|
||||
logger.info(
|
||||
f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}"
|
||||
)
|
||||
flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap)
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
flux.enable_block_swap(args.blocks_to_swap)
|
||||
|
||||
if not cache_latents:
|
||||
# load VAE here if not cached
|
||||
@@ -443,82 +461,120 @@ def train(args):
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# memory efficient block swapping
|
||||
|
||||
def get_block_unit(dbl_blocks, sgl_blocks, index: int):
|
||||
if index < len(dbl_blocks):
|
||||
return (dbl_blocks[index],)
|
||||
else:
|
||||
index -= len(dbl_blocks)
|
||||
index *= 2
|
||||
return (sgl_blocks[index], sgl_blocks[index + 1])
|
||||
|
||||
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
|
||||
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
|
||||
# print(f"Backward: Move block {bidx_to_cpu} to CPU")
|
||||
for block in blocks_to_cpu:
|
||||
block = block.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
|
||||
for block in blocks_to_cuda:
|
||||
block = block.to(dvc, non_blocking=True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
|
||||
return bidx_to_cpu, bidx_to_cuda
|
||||
|
||||
blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
|
||||
blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
|
||||
|
||||
futures[block_idx_to_cuda] = thread_pool.submit(
|
||||
move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
|
||||
)
|
||||
|
||||
def wait_blocks_move(block_idx, futures):
|
||||
if block_idx not in futures:
|
||||
return
|
||||
# print(f"Backward: Wait for block {block_idx}")
|
||||
# start_time = time.perf_counter()
|
||||
future = futures.pop(block_idx)
|
||||
future.result()
|
||||
# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
double_blocks_to_swap = args.double_blocks_to_swap
|
||||
single_blocks_to_swap = args.single_blocks_to_swap
|
||||
blocks_to_swap = args.blocks_to_swap
|
||||
num_double_blocks = 19 # len(flux.double_blocks)
|
||||
num_single_blocks = 38 # len(flux.single_blocks)
|
||||
handled_double_block_indices = set()
|
||||
handled_single_block_indices = set()
|
||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||
handled_unit_indices = set()
|
||||
|
||||
n = 1 # only asyncronous purpose, no need to increase this number
|
||||
# n = 2
|
||||
# n = max(1, os.cpu_count() // 2)
|
||||
thread_pool = ThreadPoolExecutor(max_workers=n)
|
||||
futures = {}
|
||||
|
||||
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||
for parameter, param_name in zip(param_group["params"], param_name_group):
|
||||
if parameter.requires_grad:
|
||||
grad_hook = None
|
||||
|
||||
if double_blocks_to_swap:
|
||||
if param_name.startswith("double_blocks"):
|
||||
if blocks_to_swap:
|
||||
is_double = param_name.startswith("double_blocks")
|
||||
is_single = param_name.startswith("single_blocks")
|
||||
if is_double or is_single:
|
||||
block_idx = int(param_name.split(".")[1])
|
||||
if (
|
||||
block_idx not in handled_double_block_indices
|
||||
and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1
|
||||
and block_idx < num_double_blocks - 1
|
||||
):
|
||||
# swap next (already backpropagated) block
|
||||
handled_double_block_indices.add(block_idx)
|
||||
block_idx_cpu = block_idx + 1
|
||||
block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu)
|
||||
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
|
||||
if unit_idx not in handled_unit_indices:
|
||||
# swap following (already backpropagated) block
|
||||
handled_unit_indices.add(unit_idx)
|
||||
|
||||
# create swap hook
|
||||
def create_double_swap_grad_hook(bidx, bidx_cuda):
|
||||
def __grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
# if n blocks were already backpropagated
|
||||
num_blocks_propagated = num_block_units - unit_idx - 1
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
||||
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
|
||||
if swapping or waiting:
|
||||
block_idx_to_cpu = num_block_units - num_blocks_propagated
|
||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
||||
block_idx_to_wait = unit_idx - 1
|
||||
|
||||
# swap blocks if necessary
|
||||
flux.double_blocks[bidx].to("cpu")
|
||||
flux.double_blocks[bidx_cuda].to(accelerator.device)
|
||||
# print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
|
||||
# create swap hook
|
||||
def create_swap_grad_hook(
|
||||
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
|
||||
):
|
||||
def __grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
|
||||
return __grad_hook
|
||||
# print(f"Backward: {uidx}, {swpng}, {wtng}")
|
||||
if swpng:
|
||||
submit_move_blocks(
|
||||
futures,
|
||||
thread_pool,
|
||||
bidx_to_cpu,
|
||||
bidx_to_cuda,
|
||||
flux.double_blocks,
|
||||
flux.single_blocks,
|
||||
accelerator.device,
|
||||
)
|
||||
if wtng:
|
||||
wait_blocks_move(bidx_to_wait, futures)
|
||||
|
||||
grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda)
|
||||
if single_blocks_to_swap:
|
||||
if param_name.startswith("single_blocks"):
|
||||
block_idx = int(param_name.split(".")[1])
|
||||
if (
|
||||
block_idx not in handled_single_block_indices
|
||||
and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1
|
||||
and block_idx < num_single_blocks - 1
|
||||
):
|
||||
handled_single_block_indices.add(block_idx)
|
||||
block_idx_cpu = block_idx + 1
|
||||
block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu)
|
||||
# print(param_name, block_idx_cpu, block_idx_cuda)
|
||||
return __grad_hook
|
||||
|
||||
# create swap hook
|
||||
def create_single_swap_grad_hook(bidx, bidx_cuda):
|
||||
def __grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, param_group)
|
||||
tensor.grad = None
|
||||
|
||||
# swap blocks if necessary
|
||||
flux.single_blocks[bidx].to("cpu")
|
||||
flux.single_blocks[bidx_cuda].to(accelerator.device)
|
||||
# print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
|
||||
|
||||
return __grad_hook
|
||||
|
||||
grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda)
|
||||
grad_hook = create_swap_grad_hook(
|
||||
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
|
||||
)
|
||||
|
||||
if grad_hook is None:
|
||||
|
||||
@@ -547,10 +603,15 @@ def train(args):
|
||||
num_parameters_per_group = [0] * len(optimizers)
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
double_blocks_to_swap = args.double_blocks_to_swap
|
||||
single_blocks_to_swap = args.single_blocks_to_swap
|
||||
blocks_to_swap = args.blocks_to_swap
|
||||
num_double_blocks = 19 # len(flux.double_blocks)
|
||||
num_single_blocks = 38 # len(flux.single_blocks)
|
||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||
|
||||
n = 1 # only asyncronous purpose, no need to increase this number
|
||||
# n = max(1, os.cpu_count() // 2)
|
||||
thread_pool = ThreadPoolExecutor(max_workers=n)
|
||||
futures = {}
|
||||
|
||||
for opt_idx, optimizer in enumerate(optimizers):
|
||||
for param_group in optimizer.param_groups:
|
||||
@@ -571,18 +632,30 @@ def train(args):
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
|
||||
# swap blocks if necessary
|
||||
if btype == "double" and double_blocks_to_swap:
|
||||
if bidx >= num_double_blocks - double_blocks_to_swap:
|
||||
bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx)
|
||||
flux.double_blocks[bidx].to("cpu")
|
||||
flux.double_blocks[bidx_cuda].to(accelerator.device)
|
||||
# print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
|
||||
elif btype == "single" and single_blocks_to_swap:
|
||||
if bidx >= num_single_blocks - single_blocks_to_swap:
|
||||
bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx)
|
||||
flux.single_blocks[bidx].to("cpu")
|
||||
flux.single_blocks[bidx_cuda].to(accelerator.device)
|
||||
# print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
|
||||
if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)):
|
||||
unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2
|
||||
num_blocks_propagated = num_block_units - unit_idx
|
||||
|
||||
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
|
||||
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
|
||||
|
||||
if swapping:
|
||||
block_idx_to_cpu = num_block_units - num_blocks_propagated
|
||||
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
|
||||
# print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}")
|
||||
submit_move_blocks(
|
||||
futures,
|
||||
thread_pool,
|
||||
block_idx_to_cpu,
|
||||
block_idx_to_cuda,
|
||||
flux.double_blocks,
|
||||
flux.single_blocks,
|
||||
accelerator.device,
|
||||
)
|
||||
|
||||
if waiting:
|
||||
block_idx_to_wait = unit_idx - 1
|
||||
wait_blocks_move(block_idx_to_wait, futures)
|
||||
|
||||
return optimizer_hook
|
||||
|
||||
@@ -881,24 +954,26 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="skip latents validity check / latentsの正当性チェックをスキップする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--double_blocks_to_swap",
|
||||
"--blocks_to_swap",
|
||||
type=int,
|
||||
default=None,
|
||||
help="[EXPERIMENTAL] "
|
||||
"Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes."
|
||||
"Sets the number of blocks (~640MB) to swap during the forward and backward passes."
|
||||
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
|
||||
" / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。"
|
||||
" / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。"
|
||||
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--double_blocks_to_swap",
|
||||
type=int,
|
||||
default=None,
|
||||
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--single_blocks_to_swap",
|
||||
type=int,
|
||||
default=None,
|
||||
help="[EXPERIMENTAL] "
|
||||
"Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes."
|
||||
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
|
||||
" / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。"
|
||||
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
|
||||
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
|
||||
Reference in New Issue
Block a user