feat: support block_to_swap for FLUX.1 ControlNet training

This commit is contained in:
Kohya S
2024-12-03 08:43:26 +09:00
parent e369b9a252
commit 8b36d907d8
2 changed files with 45 additions and 14 deletions

View File

@@ -119,9 +119,7 @@ def train(args):
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension
args.train_data_dir, args.conditioning_data_dir, args.caption_extension
)
}
]
@@ -263,13 +261,17 @@ def train(args):
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
flux.requires_grad_(False)
flux.to(accelerator.device)
# load controlnet
controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
controlnet = flux_utils.load_controlnet(
args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
)
controlnet.train()
if args.gradient_checkpointing:
if not args.deepspeed:
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
# block swap
@@ -296,7 +298,11 @@ def train(args):
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
# ControlNet only has two blocks, so we can keep it on GPU
# controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
else:
flux.to(accelerator.device)
if not cache_latents:
# load VAE here if not cached
@@ -455,9 +461,7 @@ def train(args):
else:
# accelerator does some magic
# if we doesn't swap blocks, we can move the model to device
controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
controlnet = accelerator.prepare(controlnet) # , device_placement=[not is_swapping_blocks])
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
@@ -564,11 +568,13 @@ def train(args):
)
if is_swapping_blocks:
accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward()
flux.prepare_block_swap_before_forward()
# For --sample_at_first
optimizer_eval_fn()
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet)
flux_train_utils.sample_images(
accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
@@ -629,7 +635,11 @@ def train(args):
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype)
img_ids = (
flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width)
.to(device=accelerator.device)
.to(weight_dtype)
)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype)
@@ -638,7 +648,7 @@ def train(args):
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
with accelerator.autocast():
block_samples, block_single_samples = controlnet(
img=packed_noisy_model_input,
@@ -715,7 +725,15 @@ def train(args):
optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
accelerator,
args,
None,
global_step,
flux,
ae,
[clip_l, t5xxl],
sample_prompts_te_outputs,
controlnet=controlnet,
)
# 指定ステップごとにモデルを保存