add memory efficient training for FLUX.1

Kohya S
2024-08-18 16:54:18 +09:00
parent 25f77f6ef0
commit ef535ec6bb
3 changed files with 348 additions and 73 deletions

flux_train.py

@@ -1,5 +1,15 @@
# training with captions
+# Swap blocks between CPU and GPU:
+# This implementation is inspired by and based on the work of 2kpr.
+# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
+# The original idea has been adapted and extended to fit the current project's needs.
+# Key features:
+# - CPU offloading during forward and backward passes
+# - Use of fused optimizer and grad_hook for efficient gradient processing
+# - Per-block fused optimizer instances
import argparse
import copy
import math
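
The header comment above summarizes the core trick: keep only part of the model resident on the GPU and stream the remaining blocks in from CPU RAM as they are needed. A minimal, self-contained sketch of that idea, not code from this commit (block count, sizes, and names are stand-ins):

import torch
import torch.nn as nn

# eight stand-in "blocks"; real FLUX double/single blocks are ~640MB/~320MB each
blocks = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(8)])
blocks_to_swap = 4  # keep the last 4 blocks in CPU RAM while idle
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# initial placement: leading blocks on the GPU, trailing blocks offloaded
for i, block in enumerate(blocks):
    block.to(device if i < len(blocks) - blocks_to_swap else "cpu")

x = torch.randn(1, 1024, device=device)
for i, block in enumerate(blocks):
    block.to(device)  # fetch the block if it was offloaded (no-op otherwise)
    x = block(x)
    if i >= len(blocks) - blocks_to_swap:
        block.to("cpu")  # evict it again to bound peak VRAM

Each swap costs a PCIe transfer, which is why the help text for the new --*_blocks_to_swap options warns that the VRAM savings come at the expense of training speed (s/it).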
@@ -54,6 +64,12 @@ def train(args):
        )
        args.cache_text_encoder_outputs = True

+    if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
+        logger.warning(
+            "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
+        )
+        args.gradient_checkpointing = True
+
    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None
@@ -232,16 +248,25 @@ def train(args):
    # now we can delete Text Encoders to free memory
    clip_l = None
    t5xxl = None
    clean_memory_on_device(accelerator.device)

    # load FLUX
    # if we load to cpu, flux.to(fp8) takes a long time
    flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")

    if args.gradient_checkpointing:
-        flux.enable_gradient_checkpointing()
+        flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)

    flux.requires_grad_(True)

+    if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None:
+        # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
+        # This idea is based on 2kpr's great work. Thank you!
+        logger.info(
+            f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}"
+        )
+        flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap)
+
    if not cache_latents:
        # load VAE here if not cached
        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
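
Both enable_gradient_checkpointing(cpu_offload) and enable_block_swap are defined in library/flux_models.py, one of the three changed files and not shown here. As a rough sketch of what CPU-offloaded checkpointing means, assuming nothing about the actual implementation, stock PyTorch can express the same effect: checkpoint() recomputes a block's forward during backward, and save_on_cpu() parks whatever is saved for backward in CPU RAM instead of VRAM:

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_forward_with_cpu_offload(block, x):
    # tensors saved for backward are moved to (pinned) CPU memory;
    # the block's forward is recomputed during the backward pass
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        return checkpoint(block, x, use_reentrant=False)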
@@ -265,40 +290,43 @@ def train(args):
    # prepare classes required for training
    accelerator.print("prepare optimizer, data loader etc.")

-    if args.fused_optimizer_groups:
+    if args.blockwise_fused_optimizers:
        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
-        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
+        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
        # This balances memory usage and management complexity.

-        # calculate total number of parameters
-        n_total_params = sum(len(params["params"]) for params in params_to_optimize)
-        params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
-
-        # split params into groups, keeping the learning rate the same for all params in a group
-        # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
+        # split params into groups. currently different learning rates are not supported
        grouped_params = []
-        param_group = []
-        param_group_lr = -1
+        param_group = {}
        for group in params_to_optimize:
-            lr = group["lr"]
-            for p in group["params"]:
-                # if the learning rate is different for different params, start a new group
-                if lr != param_group_lr:
-                    if param_group:
-                        grouped_params.append({"params": param_group, "lr": param_group_lr})
-                    param_group = []
-                    param_group_lr = lr
-
-                param_group.append(p)
-
-                # if the group has enough parameters, start a new group
-                if len(param_group) == params_per_group:
-                    grouped_params.append({"params": param_group, "lr": param_group_lr})
-                    param_group = []
-                    param_group_lr = -1
-
-        if param_group:
-            grouped_params.append({"params": param_group, "lr": param_group_lr})
+            named_parameters = list(flux.named_parameters())
+            assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
+            for p, np in zip(group["params"], named_parameters):
+                # determine target layer and block index for each parameter
+                block_type = "other"  # double, single or other
+                if np[0].startswith("double_blocks"):
+                    block_idx = int(np[0].split(".")[1])
+                    block_type = "double"
+                elif np[0].startswith("single_blocks"):
+                    block_idx = int(np[0].split(".")[1])
+                    block_type = "single"
+                else:
+                    block_idx = -1
+
+                param_group_key = (block_type, block_idx)
+                if param_group_key not in param_group:
+                    param_group[param_group_key] = []
+                param_group[param_group_key].append(p)
+
+        block_types_and_indices = []
+        for param_group_key, param_group in param_group.items():
+            block_types_and_indices.append(param_group_key)
+            grouped_params.append({"params": param_group, "lr": args.learning_rate})
+
+            num_params = 0
+            for p in param_group:
+                num_params += p.numel()
+            accelerator.print(f"block {param_group_key}: {num_params} parameters")

        # prepare optimizers for each group
        optimizers = []
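
The loop above keys every parameter by (block_type, block_idx) derived from its name, so each FLUX block ends up with its own fused optimizer instance. The same mapping as a standalone function, applied to a few hypothetical parameter names:

def block_key(name: str):
    # "double_blocks.3.img_attn.qkv.weight" -> ("double", 3), etc.
    block_type, block_idx = "other", -1
    if name.startswith("double_blocks"):
        block_type, block_idx = "double", int(name.split(".")[1])
    elif name.startswith("single_blocks"):
        block_type, block_idx = "single", int(name.split(".")[1])
    return (block_type, block_idx)

assert block_key("double_blocks.3.img_attn.qkv.weight") == ("double", 3)
assert block_key("single_blocks.17.linear1.weight") == ("single", 17)
assert block_key("final_layer.linear.weight") == ("other", -1)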
@@ -307,7 +335,7 @@ def train(args):
            optimizers.append(optimizer)
        optimizer = optimizers[0]  # avoid error in the following code

-        logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
+        logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
    else:
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
@@ -341,7 +369,7 @@ def train(args):
        train_dataset_group.set_max_train_steps(args.max_train_steps)

    # prepare lr scheduler
-    if args.fused_optimizer_groups:
+    if args.blockwise_fused_optimizers:
        # prepare lr schedulers for each optimizer
        lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
@@ -414,7 +442,7 @@ def train(args):
                    parameter.register_post_accumulate_grad_hook(__grad_hook)

-    elif args.fused_optimizer_groups:
+    elif args.blockwise_fused_optimizers:
        # prepare for additional optimizers and lr schedulers
        for i in range(1, len(optimizers)):
            optimizers[i] = accelerator.prepare(optimizers[i])
@@ -429,22 +457,46 @@ def train(args):
        num_parameters_per_group = [0] * len(optimizers)
        parameter_optimizer_map = {}

+        double_blocks_to_swap = args.double_blocks_to_swap
+        single_blocks_to_swap = args.single_blocks_to_swap
+        num_double_blocks = len(flux.double_blocks)
+        num_single_blocks = len(flux.single_blocks)
+
        for opt_idx, optimizer in enumerate(optimizers):
            for param_group in optimizer.param_groups:
                for parameter in param_group["params"]:
                    if parameter.requires_grad:
+                        block_type, block_idx = block_types_and_indices[opt_idx]

-                        def optimizer_hook(parameter: torch.Tensor):
-                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
-
-                            i = parameter_optimizer_map[parameter]
-                            optimizer_hooked_count[i] += 1
-                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
-                                optimizers[i].step()
-                                optimizers[i].zero_grad(set_to_none=True)
-
-                        parameter.register_post_accumulate_grad_hook(optimizer_hook)
+                        def create_optimizer_hook(btype, bidx):
+                            def optimizer_hook(parameter: torch.Tensor):
+                                # print(f"optimizer_hook: {btype}, {bidx}")
+                                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                    accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                                i = parameter_optimizer_map[parameter]
+                                optimizer_hooked_count[i] += 1
+                                if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+                                    optimizers[i].step()
+                                    optimizers[i].zero_grad(set_to_none=True)
+
+                                # swap blocks if necessary
+                                if btype == "double" and double_blocks_to_swap:
+                                    if bidx >= num_double_blocks - double_blocks_to_swap:
+                                        bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx)
+                                        flux.double_blocks[bidx].to("cpu")
+                                        flux.double_blocks[bidx_cuda].to(accelerator.device)
+                                        # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
+                                elif btype == "single" and single_blocks_to_swap:
+                                    if bidx >= num_single_blocks - single_blocks_to_swap:
+                                        bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx)
+                                        flux.single_blocks[bidx].to("cpu")
+                                        flux.single_blocks[bidx_cuda].to(accelerator.device)
+                                        # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
+
+                            return optimizer_hook
+
+                        parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx))
                        parameter_optimizer_map[parameter] = opt_idx
                        num_parameters_per_group[opt_idx] += 1
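
Two details in the hunk above are easy to miss. First, the hook is built by the create_optimizer_hook factory because Python closures capture variables, not values; closing directly over the loop variables block_type and block_idx would leave every hook seeing whatever block came last. A standalone demonstration of the pitfall:

hooks_wrong = [lambda: i for i in range(3)]
print([h() for h in hooks_wrong])  # [2, 2, 2] - all closures share the final i

def make_hook(j):
    return lambda: j  # j is bound once per factory call

hooks_right = [make_hook(i) for i in range(3)]
print([h() for h in hooks_right])  # [0, 1, 2]

Second, the eviction arithmetic: FLUX.1 has 19 double blocks, so with double_blocks_to_swap=6 the swap branch fires for blocks 13..18 and computes, for example, bidx_cuda = 6 - (19 - 18) = 5 when block 18 finishes its backward; the just-finished trailing block goes to CPU while a leading block is staged back onto the GPU, presumably one the forward pass had evicted (the forward-side swapping lives in library/flux_models.py and is not part of this hunk).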
@@ -487,6 +539,9 @@ def train(args):
            init_kwargs=init_kwargs,
        )

+    if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None:
+        flux.prepare_block_swap_before_forward()
+
    # For --sample_at_first
    flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
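
prepare_block_swap_before_forward() is likewise defined in library/flux_models.py and not shown in this diff; the call site suggests it resets block placement to a known state before a forward pass that does not run the training hooks (sampling here). A guess at the intent, as a sketch only:

import torch

def prepare_block_swap_before_forward_sketch(blocks, blocks_to_swap, device):
    # restore the canonical placement: leading blocks on the GPU, trailing on CPU
    for i, block in enumerate(blocks):
        block.to(device if i < len(blocks) - blocks_to_swap else "cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached VRAM freed by the offloaded blocks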
@@ -502,7 +557,7 @@ def train(args):
        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step

-            if args.fused_optimizer_groups:
+            if args.blockwise_fused_optimizers:
                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step

            with accelerator.accumulate(*training_models):
@@ -591,7 +646,7 @@ def train(args):
                # backward
                accelerator.backward(loss)

-                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                        params_to_clip = []
                        for m in training_models:
@@ -604,7 +659,7 @@ def train(args):
                else:
                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
                    lr_scheduler.step()
-                    if args.fused_optimizer_groups:
+                    if args.blockwise_fused_optimizers:
                        for i in range(1, len(optimizers)):
                            lr_schedulers[i].step()
@@ -614,7 +669,7 @@ def train(args):
                global_step += 1

                flux_train_utils.sample_images(
-                    accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
+                    accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
                )

                # save the model at each specified number of steps
@@ -673,8 +728,6 @@ def train(args):
    is_main_process = accelerator.is_main_process
    # if is_main_process:
    flux = accelerator.unwrap_model(flux)
-    clip_l = accelerator.unwrap_model(clip_l)
-    t5xxl = accelerator.unwrap_model(t5xxl)

    accelerator.end_training()
@@ -707,13 +760,43 @@ def setup_parser() -> argparse.ArgumentParser:
"--fused_optimizer_groups",
type=int,
default=None,
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
)
parser.add_argument(
"--blockwise_fused_optimizers",
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする",
)
parser.add_argument(
"--double_blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップする'変換ブロック'約640MBの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
)
parser.add_argument(
"--single_blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップする'変換ブロック'約320MBの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
)
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
)
return parser
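
Putting the new options together, an illustrative parse using this file's setup_parser() (the values are placeholders, not recommendations from the commit, and this assumes the parser's remaining arguments all carry defaults):

parser = setup_parser()
args = parser.parse_args(
    "--blockwise_fused_optimizers --double_blocks_to_swap 6 "
    "--single_blocks_to_swap 18 --cpu_offload_checkpointing".split()
)
print(args.blockwise_fused_optimizers, args.double_blocks_to_swap)  # True 6

Note that train() will force gradient_checkpointing on when cpu_offload_checkpointing is set, per the warning added near the top of the diff.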