feat: add block swap for FLUX.1/SD3 LoRA training
@@ -78,6 +78,10 @@ def train(args):
        )
        args.gradient_checkpointing = True

    assert (
        args.blocks_to_swap is None or args.blocks_to_swap == 0
    ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None
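
Note: the assert above enforces that block swapping cannot be combined with cpu_offload_checkpointing. Below is a minimal sketch of the same guard written as an explicit check, assuming only the two attributes that appear in the hunk; the function name is illustrative and not part of the patch:

    import argparse

    def check_block_swap_args(args: argparse.Namespace) -> None:
        # blocks_to_swap may be None (unset) or 0 (disabled); anything else enables swapping
        swapping = args.blocks_to_swap is not None and args.blocks_to_swap > 0
        if swapping and args.cpu_offload_checkpointing:
            raise ValueError("blocks_to_swap is not supported with cpu_offload_checkpointing")

    # example: this combination would be rejected
    # check_block_swap_args(argparse.Namespace(blocks_to_swap=6, cpu_offload_checkpointing=True))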
@@ -518,47 +522,6 @@ def train(args):
                        parameter_optimizer_map[parameter] = opt_idx
                        num_parameters_per_group[opt_idx] += 1

    # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
    if False: # is_swapping_blocks:
        import library.custom_offloading_utils as custom_offloading_utils

        num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
        num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
        double_blocks_to_swap = args.blocks_to_swap // 2
        single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2

        offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device)
        offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device)

        param_name_pairs = []
        if not args.blockwise_fused_optimizers:
            for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                param_name_pairs.extend(zip(param_group["params"], param_name_group))
        else:
            # named_parameters is a list of (name, parameter) pairs
            param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()])

        for parameter, param_name in param_name_pairs:
            if not parameter.requires_grad:
                continue

            is_double = param_name.startswith("double_blocks")
            is_single = param_name.startswith("single_blocks")
            if not is_double and not is_single:
                continue

            block_index = int(param_name.split(".")[1])
            if is_double:
                blocks = flux.double_blocks
                offloader = offloader_double
            else:
                blocks = flux.single_blocks
                offloader = offloader_single

            grad_hook = offloader.create_grad_hook(blocks, block_index)
            if grad_hook is not None:
                parameter.register_post_accumulate_grad_hook(grad_hook)

    # calculate the number of epochs / epoch数を計算する
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
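
Note: the block removed above (it had already been disabled behind `if False:`) is what wires block swapping into the backward pass: `blocks_to_swap` is split between FLUX.1's double and single blocks (for example, `blocks_to_swap=6` gives `6 // 2 = 3` double blocks and `(6 - 3) * 2 = 6` single blocks), and every trainable parameter in a swapped block gets a hook that fires once its gradient has been accumulated. Below is a minimal, self-contained sketch of that hook pattern, assuming PyTorch's `Tensor.register_post_accumulate_grad_hook` (available from torch 2.1); `SimpleBlockOffloader` is an illustrative stand-in, not the `custom_offloading_utils.TrainOffloader` implementation:

    import torch
    import torch.nn as nn

    class SimpleBlockOffloader:
        # Illustrative stand-in: move a block back to CPU once all of its grads are ready.
        def __init__(self, block: nn.Module):
            self.block = block
            self.num_trainable = sum(1 for p in block.parameters() if p.requires_grad)
            self.remaining = self.num_trainable

        def create_grad_hook(self):
            def hook(param: torch.Tensor) -> None:
                # called after .grad has been accumulated for this parameter
                self.remaining -= 1
                if self.remaining == 0:
                    # every parameter of the block now has its gradient; the block
                    # can leave the GPU for the rest of this step
                    self.block.to("cpu")
                    self.remaining = self.num_trainable
            return hook

    # usage: register one hook per trainable parameter of a block slated for swapping
    block = nn.Linear(8, 8)  # stands in for a double/single transformer block
    offloader = SimpleBlockOffloader(block)
    for p in block.parameters():
        if p.requires_grad:
            p.register_post_accumulate_grad_hook(offloader.create_grad_hook())

A real offloader also has to move the block (and its gradients) back to the GPU before the next forward pass; the sketch only shows where the hooks attach.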
@@ -827,6 +790,7 @@ def setup_parser() -> argparse.ArgumentParser:
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)  # TODO remove this from here
    train_util.add_dit_training_arguments(parser)
    flux_train_utils.add_flux_train_arguments(parser)

    parser.add_argument(
@@ -851,16 +815,6 @@ def setup_parser() -> argparse.ArgumentParser:
        action="store_true",
        help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
    )
    parser.add_argument(
        "--blocks_to_swap",
        type=int,
        default=None,
        help="[EXPERIMENTAL] "
        "Sets the number of blocks (~640MB) to swap during the forward and backward passes. "
        "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
        " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。"
        "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
    )
    parser.add_argument(
        "--double_blocks_to_swap",
        type=int,
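
Note: the help text above puts each swapped block at roughly 640MB. Below is a back-of-the-envelope sketch of the trade-off it describes, purely illustrative; actual savings depend on the model, precision, and the value finally passed to the flag:

    # rough VRAM-saving estimate implied by the help text (~640MB per swapped block)
    def estimated_vram_saving_gb(blocks_to_swap: int, mb_per_block: float = 640.0) -> float:
        return blocks_to_swap * mb_per_block / 1024.0

    for n in (0, 6, 12, 18):
        print(f"--blocks_to_swap {n}: roughly {estimated_vram_saving_gb(n):.1f} GB less VRAM, at slower s/it")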