feat: add block swap for FLUX.1/SD3 LoRA training

Author: Kohya S
Date: 2024-11-12 21:39:13 +09:00
Parent: 17cf249d76
Commit: 2cb7a6db02

14 changed files with 288 additions and 629 deletions


@@ -78,6 +78,10 @@ def train(args):
        )
        args.gradient_checkpointing = True

    assert (
        args.blocks_to_swap is None or args.blocks_to_swap == 0
    ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None
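
Note on the new guard in this hunk: the added assert makes block swapping and cpu_offload_checkpointing mutually exclusive. Below is a minimal standalone sketch of the same check; the validate_swap_options helper and the argparse.Namespace values are illustrative only, not part of the commit.

import argparse

def validate_swap_options(args: argparse.Namespace) -> None:
    # mirrors the assert added above: swapping blocks while also offloading
    # checkpointed activations to the CPU is rejected
    swapping = args.blocks_to_swap is not None and args.blocks_to_swap > 0
    assert not (swapping and args.cpu_offload_checkpointing), \
        "blocks_to_swap is not supported with cpu_offload_checkpointing"

validate_swap_options(argparse.Namespace(blocks_to_swap=4, cpu_offload_checkpointing=False))  # passes
# blocks_to_swap=4 together with cpu_offload_checkpointing=True would raise AssertionError
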
@@ -518,47 +522,6 @@ def train(args):
                        parameter_optimizer_map[parameter] = opt_idx
                        num_parameters_per_group[opt_idx] += 1

    # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
    if False:  # is_swapping_blocks:
        import library.custom_offloading_utils as custom_offloading_utils

        num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
        num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
        double_blocks_to_swap = args.blocks_to_swap // 2
        single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2

        offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device)
        offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device)

        param_name_pairs = []
        if not args.blockwise_fused_optimizers:
            for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                param_name_pairs.extend(zip(param_group["params"], param_name_group))
        else:
            # named_parameters is a list of (name, parameter) pairs
            param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()])

        for parameter, param_name in param_name_pairs:
            if not parameter.requires_grad:
                continue

            is_double = param_name.startswith("double_blocks")
            is_single = param_name.startswith("single_blocks")
            if not is_double and not is_single:
                continue

            block_index = int(param_name.split(".")[1])
            if is_double:
                blocks = flux.double_blocks
                offloader = offloader_double
            else:
                blocks = flux.single_blocks
                offloader = offloader_single

            grad_hook = offloader.create_grad_hook(blocks, block_index)
            if grad_hook is not None:
                parameter.register_post_accumulate_grad_hook(grad_hook)

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
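
The removed hunk above shows how the earlier experimental path wired block swapping into the backward pass: with --blocks_to_swap N, roughly N // 2 double blocks and (N - N // 2) * 2 single blocks were eligible for swapping (e.g. N = 6 gives 3 double and 6 single blocks), and each eligible parameter received a post-accumulate-grad hook from custom_offloading_utils.TrainOffloader. The sketch below is not that class, only a minimal illustration of the hook pattern, assuming PyTorch >= 2.1 for Tensor.register_post_accumulate_grad_hook; a real offloader would register one hook per block, synchronize CUDA streams, and prefetch the next block back onto the GPU.

import torch
import torch.nn as nn

class NaiveTrainOffloader:
    """Illustrative stand-in for the TrainOffloader used in the removed code."""

    def __init__(self, blocks: nn.ModuleList, blocks_to_swap: int, device: torch.device):
        self.blocks = blocks
        self.device = device
        # simplifying assumption: only the last `blocks_to_swap` blocks ever leave the GPU
        self.swap_indices = set(range(len(blocks) - blocks_to_swap, len(blocks)))

    def create_grad_hook(self, block_index: int):
        if block_index not in self.swap_indices:
            return None  # as in the removed loop, some parameters get no hook

        def hook(param: torch.Tensor) -> None:
            # runs once param.grad is fully accumulated, i.e. after this block's
            # backward work is done, so the block can be moved off the GPU
            self.blocks[block_index].to("cpu", non_blocking=True)

        return hook

# wiring analogous to the removed loop over (parameter, param_name) pairs:
# for name, p in flux.named_parameters():
#     if p.requires_grad and name.startswith("single_blocks"):
#         hook = offloader_single.create_grad_hook(int(name.split(".")[1]))
#         if hook is not None:
#             p.register_post_accumulate_grad_hook(hook)
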
@@ -827,6 +790,7 @@ def setup_parser() -> argparse.ArgumentParser:
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)  # TODO remove this from here
    train_util.add_dit_training_arguments(parser)
    flux_train_utils.add_flux_train_arguments(parser)

    parser.add_argument(
@@ -851,16 +815,6 @@ def setup_parser() -> argparse.ArgumentParser:
        action="store_true",
        help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
    )
    parser.add_argument(
        "--blocks_to_swap",
        type=int,
        default=None,
        help="[EXPERIMENTAL] "
        "Sets the number of blocks (~640MB) to swap during the forward and backward passes. "
        "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
        " / 順伝播および逆伝播中にスワップするブロック（約640MB）の数を設定します。"
        "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度（s/it）も低下します。",
    )
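
For a rough sense of the trade-off behind the removed flag: taking the ~640 MB per block figure from the help string at face value, the expected peak-VRAM saving scales linearly with the number of swapped blocks (actual savings depend on the model variant and precision); the value 6 below is purely hypothetical.

# back-of-the-envelope estimate using the ~640 MB/block figure quoted in the help text above
blocks_to_swap = 6          # hypothetical value, e.g. --blocks_to_swap 6
approx_block_mb = 640
print(f"~{blocks_to_swap * approx_block_mb / 1024:.1f} GB lower peak VRAM")  # prints ~3.8 GB
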
    parser.add_argument(
        "--double_blocks_to_swap",
        type=int,