From 8b36d907d8635dca64224574b5cb15013e00809d Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 3 Dec 2024 08:43:26 +0900
Subject: [PATCH] feat: support block_to_swap for FLUX.1 ControlNet training

---
 README.md                 | 13 +++++++++++
 flux_train_control_net.py | 46 +++++++++++++++++++++++++++------------
 2 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 6a5cdd34..f0272519 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,11 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
+
+Dec 3, 2024:
+
+- `--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training).
+
 Dec 2, 2024:
 
 - FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details.
 
@@ -276,6 +281,14 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_tr
 --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed
 ```
 
+For 24GB VRAM GPUs, you can train with 16 blocks swapped, caching latents and text encoder outputs to disk, with a batch size of 1. Remove `--deepspeed` and add the following options to the command above (not fully tested):
+```
+--blocks_to_swap 16 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk
+```
+
+Training is also possible on 16GB VRAM GPUs with around 30 blocks swapped.
+
+`--gradient_accumulation_steps` is also available. The default is 1 (no accumulation); the original PR uses 8.
 
 ### FLUX.1 OFT training
 
diff --git a/flux_train_control_net.py b/flux_train_control_net.py
index bb27c35e..5548fd99 100644
--- a/flux_train_control_net.py
+++ b/flux_train_control_net.py
@@ -119,9 +119,7 @@ def train(args):
             "datasets": [
                 {
                     "subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
-                        args.train_data_dir,
-                        args.conditioning_data_dir,
-                        args.caption_extension
+                        args.train_data_dir, args.conditioning_data_dir, args.caption_extension
                     )
                 }
             ]
@@ -263,13 +261,17 @@ def train(args):
         args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
     )
     flux.requires_grad_(False)
-    flux.to(accelerator.device)
 
     # load controlnet
-    controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
+    controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
+    controlnet = flux_utils.load_controlnet(
+        args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
+    )
     controlnet.train()
 
     if args.gradient_checkpointing:
+        if not args.deepspeed:
+            flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
         controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
 
     # block swap
@@ -296,7 +298,11 @@ def train(args):
         # This idea is based on 2kpr's great work. Thank you!
         logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
         flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
-        controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
+        flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+        # ControlNet has only two blocks, so we can keep it on the GPU
+        # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
+    else:
+        flux.to(accelerator.device)
 
     if not cache_latents:
         # load VAE here if not cached
@@ -455,9 +461,7 @@ def train(args):
     else:
         # accelerator does some magic
        # if we don't swap blocks, we can move the model to the device
-        controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks])
-        if is_swapping_blocks:
-            accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+        controlnet = accelerator.prepare(controlnet)  # , device_placement=[not is_swapping_blocks])
     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     # Experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
@@ -564,11 +568,13 @@ def train(args):
     )
 
     if is_swapping_blocks:
-        accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward()
+        flux.prepare_block_swap_before_forward()
 
     # For --sample_at_first
     optimizer_eval_fn()
-    flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet)
+    flux_train_utils.sample_images(
+        accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
+    )
     optimizer_train_fn()
     if len(accelerator.trackers) > 0:
         # log empty object to commit the sample images to wandb
@@ -629,7 +635,11 @@ def train(args):
                 # pack latents and get img_ids
                 packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
                 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
-                img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype)
+                img_ids = (
+                    flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width)
+                    .to(device=accelerator.device)
+                    .to(weight_dtype)
+                )
 
                 # get guidance: ensure args.guidance_scale is float
                 guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype)
@@ -638,7 +648,7 @@ def train(args):
                 l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
                 if not args.apply_t5_attn_mask:
                     t5_attn_mask = None
-
+
                 with accelerator.autocast():
                     block_samples, block_single_samples = controlnet(
                         img=packed_noisy_model_input,
@@ -715,7 +725,15 @@ def train(args):
                     optimizer_eval_fn()
                     flux_train_utils.sample_images(
-                        accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
+                        accelerator,
+                        args,
+                        None,
+                        global_step,
+                        flux,
+                        ae,
+                        [clip_l, t5xxl],
+                        sample_prompts_te_outputs,
+                        controlnet=controlnet,
                     )
 
                     # save the model every specified number of steps
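
The block-swap technique enabled above by `flux.enable_block_swap()` keeps a chosen number of FLUX.1 transformer blocks on the CPU and moves them onto the GPU only while they are needed, trading training speed for peak VRAM. The minimal sketch below illustrates the basic idea only; it is not the sd-scripts implementation (which also covers the backward pass and is more careful about transfer overhead), and `ToyBlockSwapModel` and its methods are hypothetical names that merely echo the `move_to_device_except_swap_blocks` API used in the patch:

```python
# Toy illustration of block swapping (hypothetical names, not the sd-scripts API).
import torch
import torch.nn as nn


class ToyBlockSwapModel(nn.Module):
    """A stack of blocks where the last `blocks_to_swap` blocks live on the CPU between uses."""

    def __init__(self, num_blocks: int = 8, blocks_to_swap: int = 4, dim: int = 64):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])
        self.blocks_to_swap = blocks_to_swap

    def move_to_device_except_swap_blocks(self, device: torch.device) -> None:
        # Resident blocks move to the compute device; the last `blocks_to_swap` stay on the CPU.
        num_resident = len(self.blocks) - self.blocks_to_swap
        for i, block in enumerate(self.blocks):
            block.to(device if i < num_resident else "cpu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_resident = len(self.blocks) - self.blocks_to_swap
        for i, block in enumerate(self.blocks):
            swapped = i >= num_resident
            if swapped:
                block.to(x.device)  # swap in just before this block runs
            x = block(x)
            if swapped:
                block.to("cpu")  # swap out right after, freeing VRAM for the next block
        return x


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ToyBlockSwapModel(num_blocks=8, blocks_to_swap=4)
model.move_to_device_except_swap_blocks(device)
with torch.no_grad():  # forward-only demo; real training must also handle swaps in backward
    out = model(torch.randn(2, 64, device=device))
print(out.shape)  # torch.Size([2, 64])
```

Each swapped block costs two host-device transfers per forward pass, so swapping more blocks lowers peak VRAM but slows training; this is why the README change above suggests 16 swapped blocks for 24GB GPUs and around 30 for 16GB GPUs.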