mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add memory efficient training for FLUX.1
This commit is contained in:
175
flux_train.py
175
flux_train.py
@@ -1,5 +1,15 @@
|
||||
# training with captions
|
||||
|
||||
# Swap blocks between CPU and GPU:
|
||||
# This implementation is inspired by and based on the work of 2kpr.
|
||||
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
|
||||
# The original idea has been adapted and extended to fit the current project's needs.
|
||||
|
||||
# Key features:
|
||||
# - CPU offloading during forward and backward passes
|
||||
# - Use of fused optimizer and grad_hook for efficient gradient processing
|
||||
# - Per-block fused optimizer instances
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
@@ -54,6 +64,12 @@ def train(args):
|
||||
)
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
|
||||
logger.warning(
|
||||
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
|
||||
)
|
||||
args.gradient_checkpointing = True
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
@@ -232,16 +248,25 @@ def train(args):
|
||||
# now we can delete Text Encoders to free memory
|
||||
clip_l = None
|
||||
t5xxl = None
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# load FLUX
|
||||
# if we load to cpu, flux.to(fp8) takes a long time
|
||||
flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
flux.enable_gradient_checkpointing()
|
||||
flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
|
||||
|
||||
flux.requires_grad_(True)
|
||||
|
||||
if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None:
|
||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
# This idea is based on 2kpr's great work. Thank you!
|
||||
logger.info(
|
||||
f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}"
|
||||
)
|
||||
flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap)
|
||||
|
||||
if not cache_latents:
|
||||
# load VAE here if not cached
|
||||
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
|
||||
@@ -265,40 +290,43 @@ def train(args):
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
if args.fused_optimizer_groups:
|
||||
if args.blockwise_fused_optimizers:
|
||||
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
|
||||
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
|
||||
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
|
||||
# This balances memory usage and management complexity.
|
||||
|
||||
# calculate total number of parameters
|
||||
n_total_params = sum(len(params["params"]) for params in params_to_optimize)
|
||||
params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
|
||||
|
||||
# split params into groups, keeping the learning rate the same for all params in a group
|
||||
# this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
|
||||
# split params into groups keyed by (block_type, block_idx); per-group learning rates are not supported yet, so every group uses args.learning_rate
|
||||
grouped_params = []
|
||||
param_group = []
|
||||
param_group_lr = -1
|
||||
param_group = {}
|
||||
for group in params_to_optimize:
|
||||
lr = group["lr"]
|
||||
for p in group["params"]:
|
||||
# if the learning rate is different for different params, start a new group
|
||||
if lr != param_group_lr:
|
||||
if param_group:
|
||||
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
||||
param_group = []
|
||||
param_group_lr = lr
|
||||
named_parameters = list(flux.named_parameters())
|
||||
assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
|
||||
for p, np in zip(group["params"], named_parameters):
|
||||
# determine target layer and block index for each parameter
|
||||
block_type = "other" # double, single or other
|
||||
if np[0].startswith("double_blocks"):
|
||||
block_idx = int(np[0].split(".")[1])
|
||||
block_type = "double"
|
||||
elif np[0].startswith("single_blocks"):
|
||||
block_idx = int(np[0].split(".")[1])
|
||||
block_type = "single"
|
||||
else:
|
||||
block_idx = -1
|
||||
|
||||
param_group.append(p)
|
||||
param_group_key = (block_type, block_idx)
|
||||
if param_group_key not in param_group:
|
||||
param_group[param_group_key] = []
|
||||
param_group[param_group_key].append(p)
|
||||
|
||||
# if the group has enough parameters, start a new group
|
||||
if len(param_group) == params_per_group:
|
||||
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
||||
param_group = []
|
||||
param_group_lr = -1
|
||||
block_types_and_indices = []
|
||||
for param_group_key, param_group in param_group.items():
|
||||
block_types_and_indices.append(param_group_key)
|
||||
grouped_params.append({"params": param_group, "lr": args.learning_rate})
|
||||
|
||||
if param_group:
|
||||
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
||||
num_params = 0
|
||||
for p in param_group:
|
||||
num_params += p.numel()
|
||||
accelerator.print(f"block {param_group_key}: {num_params} parameters")
|
||||
|
||||
# prepare optimizers for each group
|
||||
optimizers = []
|
||||
@@ -307,7 +335,7 @@ def train(args):
|
||||
optimizers.append(optimizer)
|
||||
optimizer = optimizers[0] # avoid error in the following code
|
||||
|
||||
logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
|
||||
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
|
||||
|
||||
else:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
@@ -341,7 +369,7 @@ def train(args):
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
if args.fused_optimizer_groups:
|
||||
if args.blockwise_fused_optimizers:
|
||||
# prepare lr schedulers for each optimizer
|
||||
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
|
||||
lr_scheduler = lr_schedulers[0] # avoid error in the following code
|
||||
@@ -414,7 +442,7 @@ def train(args):
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(__grad_hook)
|
||||
|
||||
elif args.fused_optimizer_groups:
|
||||
elif args.blockwise_fused_optimizers:
|
||||
# prepare for additional optimizers and lr schedulers
|
||||
for i in range(1, len(optimizers)):
|
||||
optimizers[i] = accelerator.prepare(optimizers[i])
|
||||
@@ -429,22 +457,46 @@ def train(args):
|
||||
num_parameters_per_group = [0] * len(optimizers)
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
double_blocks_to_swap = args.double_blocks_to_swap
|
||||
single_blocks_to_swap = args.single_blocks_to_swap
|
||||
num_double_blocks = len(flux.double_blocks)
|
||||
num_single_blocks = len(flux.single_blocks)
|
||||
|
||||
for opt_idx, optimizer in enumerate(optimizers):
|
||||
for param_group in optimizer.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
block_type, block_idx = block_types_and_indices[opt_idx]
|
||||
|
||||
def optimizer_hook(parameter: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||
def create_optimizer_hook(btype, bidx):
|
||||
def optimizer_hook(parameter: torch.Tensor):
|
||||
# print(f"optimizer_hook: {btype}, {bidx}")
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
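# swap blocks if necessary: when this block is among the last `*_blocks_to_swap` blocks, move block `bidx` to CPU and move block `bidx_cuda` to the accelerator device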
|
||||
if btype == "double" and double_blocks_to_swap:
|
||||
if bidx >= num_double_blocks - double_blocks_to_swap:
|
||||
bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx)
|
||||
flux.double_blocks[bidx].to("cpu")
|
||||
flux.double_blocks[bidx_cuda].to(accelerator.device)
|
||||
# print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
|
||||
elif btype == "single" and single_blocks_to_swap:
|
||||
if bidx >= num_single_blocks - single_blocks_to_swap:
|
||||
bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx)
|
||||
flux.single_blocks[bidx].to("cpu")
|
||||
flux.single_blocks[bidx_cuda].to(accelerator.device)
|
||||
# print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
|
||||
|
||||
return optimizer_hook
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx))
|
||||
parameter_optimizer_map[parameter] = opt_idx
|
||||
num_parameters_per_group[opt_idx] += 1
|
||||
|
||||
@@ -487,6 +539,9 @@ def train(args):
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None:
|
||||
flux.prepare_block_swap_before_forward()
|
||||
|
||||
# For --sample_at_first
|
||||
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
|
||||
|
||||
@@ -502,7 +557,7 @@ def train(args):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
if args.fused_optimizer_groups:
|
||||
if args.blockwise_fused_optimizers:
|
||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
@@ -591,7 +646,7 @@ def train(args):
|
||||
# backward
|
||||
accelerator.backward(loss)
|
||||
|
||||
if not (args.fused_backward_pass or args.fused_optimizer_groups):
|
||||
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
@@ -604,7 +659,7 @@ def train(args):
|
||||
else:
|
||||
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||
lr_scheduler.step()
|
||||
if args.fused_optimizer_groups:
|
||||
if args.blockwise_fused_optimizers:
|
||||
for i in range(1, len(optimizers)):
|
||||
lr_schedulers[i].step()
|
||||
|
||||
@@ -614,7 +669,7 @@ def train(args):
|
||||
global_step += 1
|
||||
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -673,8 +728,6 @@ def train(args):
|
||||
is_main_process = accelerator.is_main_process
|
||||
# if is_main_process:
|
||||
flux = accelerator.unwrap_model(flux)
|
||||
clip_l = accelerator.unwrap_model(clip_l)
|
||||
t5xxl = accelerator.unwrap_model(t5xxl)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
@@ -707,13 +760,43 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--fused_optimizer_groups",
|
||||
type=int,
|
||||
default=None,
|
||||
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
|
||||
help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--blockwise_fused_optimizers",
|
||||
action="store_true",
|
||||
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_latents_validity_check",
|
||||
action="store_true",
|
||||
help="skip latents validity check / latentsの正当性チェックをスキップする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--double_blocks_to_swap",
|
||||
type=int,
|
||||
default=None,
|
||||
help="[EXPERIMENTAL] "
|
||||
"Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes."
|
||||
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
|
||||
" / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。"
|
||||
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--single_blocks_to_swap",
|
||||
type=int,
|
||||
default=None,
|
||||
help="[EXPERIMENTAL] "
|
||||
"Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes."
|
||||
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
|
||||
" / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。"
|
||||
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user