diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index f8169bdb..61ebfb58 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -1,3 +1,6 @@
+# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
+# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
+
 import argparse
 import gc
 import json
@@ -20,6 +23,7 @@ except Exception:
     pass
 from torch.nn.parallel import DistributedDataParallel as DDP
 from accelerate.utils import set_seed
+import accelerate
 from diffusers import DDPMScheduler, ControlNetModel
 from safetensors.torch import load_file
 from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
@@ -41,7 +45,7 @@ from library.custom_train_functions import (
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
 )
-import networks.control_net_lllite as control_net_lllite
+import networks.control_net_lllite_for_train as control_net_lllite_for_train

 # TODO 他のスクリプトと共通化する / TODO: unify this with the other scripts

@@ -148,9 +152,6 @@ def train(args):
         ckpt_info,
     ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)

-    # モデルに xformers とか memory efficient attention を組み込む / integrate xformers or memory efficient attention into the model
-    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-
     # 学習を準備する / prepare for training
     if cache_latents:
         vae.to(accelerator.device, dtype=vae_dtype)
@@ -184,22 +185,53 @@ def train(args):
         )
         accelerator.wait_for_everyone()

-    # prepare ControlNet
-    network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
-    network.apply_to()
+    # prepare ControlNet-LLLite
+    control_net_lllite_for_train.replace_unet_linear_and_conv2d()

     if args.network_weights is not None:
-        info = network.load_weights(args.network_weights)
-        accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
+        accelerator.print(f"initialize U-Net with ControlNet-LLLite")
+        with accelerate.init_empty_weights():
+            unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
+        unet_lllite.to(accelerator.device, dtype=weight_dtype)
+
+        unet_sd = unet.state_dict()
+        info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
+        accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
+    else:
+        # consumes large memory, so send to GPU before creating the LLLite model
+        accelerator.print("sending U-Net to GPU")
+        unet.to(accelerator.device, dtype=weight_dtype)
+        unet_sd = unet.state_dict()
+
+        # init LLLite weights
+        accelerator.print(f"initialize U-Net with ControlNet-LLLite")
+
+        if args.lowram:
+            with accelerate.init_on_device(accelerator.device):
+                unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
+        else:
+            unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
+        unet_lllite.to(weight_dtype)
+
+        info = unet_lllite.load_lllite_weights(None, unet_sd)
+        accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
+    del unet_sd, unet
+
+    unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
+    del unet_lllite
+
+    unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
+
+    # モデルに xformers とか memory efficient attention を組み込む / integrate xformers or memory efficient attention into the model
+    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-        network.enable_gradient_checkpointing()  # may have no effect

     # 学習に必要なクラスを準備する / prepare the classes required for training
     accelerator.print("prepare optimizer, data loader etc.")

-    trainable_params = list(network.prepare_optimizer_params())
+    trainable_params = list(unet.prepare_params())
     print(f"trainable params count: {len(trainable_params)}")
     print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
@@ -232,37 +264,32 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする / experimental feature: train in fp16/bf16 including gradients, casting the entire model to fp16/bf16
-    if args.full_fp16:
-        assert (
-            args.mixed_precision == "fp16"
-        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
-        accelerator.print("enable full fp16 training.")
-        unet.to(weight_dtype)
-        network.to(weight_dtype)
-    elif args.full_bf16:
-        assert (
-            args.mixed_precision == "bf16"
-        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
-        accelerator.print("enable full bf16 training.")
-        unet.to(weight_dtype)
-        network.to(weight_dtype)
+    # if args.full_fp16:
+    #     assert (
+    #         args.mixed_precision == "fp16"
+    #     ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+    #     accelerator.print("enable full fp16 training.")
+    #     unet.to(weight_dtype)
+    # elif args.full_bf16:
+    #     assert (
+    #         args.mixed_precision == "bf16"
+    #     ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+    #     accelerator.print("enable full bf16 training.")
+    #     unet.to(weight_dtype)
+
+    unet.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい / accelerator apparently takes care of this for us
-    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, network, optimizer, train_dataloader, lr_scheduler
-    )
-    network: control_net_lllite.ControlNetLLLite
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     # transform DDP after prepare (train_network here only)
-    unet, network = train_util.transform_models_if_DDP([unet, network])
+    unet = train_util.transform_models_if_DDP([unet])[0]

     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> this is the original U-Net now, so this could actually be removed
     else:
         unet.eval()

-    network.prepare_grad_etc()
-
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する / move to CPU when caching the Text Encoder outputs
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -328,7 +355,13 @@ def train(args):
         del train_dataset_group

     # function for saving/removing
-    def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
+    def save_model(
+        ckpt_name,
+        unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
+        steps,
+        epoch_no,
+        force_sync_upload=False,
+    ):
         os.makedirs(args.output_dir, exist_ok=True)
         ckpt_file = os.path.join(args.output_dir, ckpt_name)

@@ -336,7 +369,7 @@ def train(args):
            sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
            sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"

-        unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
+        unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)

        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -351,11 +384,9 @@ def train(args):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1

-        network.on_epoch_start()  # train()
-
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
-            with accelerator.accumulate(network):
+            with accelerator.accumulate(unet):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device)
@@ -412,10 +443,9 @@ def train(args):
                 with accelerator.autocast():
                     # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
                     # 内部でcond_embに変換される / it will be converted to cond_emb inside
-                    network.set_cond_image(controlnet_image)

                     # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
-                    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
+                    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)

                 if args.v_parameterization:
                     # v-parameterization training
@@ -440,7 +470,7 @@ def train(args):

                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                    params_to_clip = network.get_trainable_params()
+                    params_to_clip = unet.get_trainable_params()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
@@ -459,7 +489,7 @@ def train(args):
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
                         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
-                        save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
+                        save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)

                         if args.save_state:
                             train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -498,7 +528,7 @@ def train(args):
            saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
            if is_main_process and saving:
                ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
-                save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
+                save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)

                remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                if remove_epoch_no is not None:
@@ -513,7 +543,7 @@ def train(args):

     # end of epoch
     if is_main_process:
-        network = accelerator.unwrap_model(network)
+        unet = accelerator.unwrap_model(unet)

     accelerator.end_training()

@@ -522,7 +552,7 @@ def train(args):

     if is_main_process:
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
-        save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
+        save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)

         print("model saved.")

diff --git a/sdxl_train_control_net_lllite_alt.py b/sdxl_train_control_net_lllite_old.py
similarity index 87%
rename from sdxl_train_control_net_lllite_alt.py
rename to sdxl_train_control_net_lllite_old.py
index 61ebfb58..f8169bdb 100644
--- a/sdxl_train_control_net_lllite_alt.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -1,6 +1,3 @@
-# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
-# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
-
 import argparse
 import gc
 import json
@@ -23,7 +20,6 @@ except Exception:
     pass
 from torch.nn.parallel import DistributedDataParallel as DDP
 from accelerate.utils import set_seed
-import accelerate
 from diffusers import DDPMScheduler, ControlNetModel
 from safetensors.torch import load_file
 from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
@@ -45,7 +41,7 @@ from library.custom_train_functions import (
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
 )
-import networks.control_net_lllite_for_train as control_net_lllite_for_train
+import networks.control_net_lllite as control_net_lllite

 # TODO 他のスクリプトと共通化する / TODO: unify this with the other scripts

@@ -152,6 +148,9 @@ def train(args):
         ckpt_info,
     ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)

+    # モデルに xformers とか memory efficient attention を組み込む / integrate xformers or memory efficient attention into the model
+    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+
     # 学習を準備する / prepare for training
     if cache_latents:
         vae.to(accelerator.device, dtype=vae_dtype)
@@ -185,53 +184,22 @@ def train(args):
         )
         accelerator.wait_for_everyone()

-    # prepare ControlNet-LLLite
-    control_net_lllite_for_train.replace_unet_linear_and_conv2d()
+    # prepare ControlNet
+    network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
+    network.apply_to()

     if args.network_weights is not None:
-        accelerator.print(f"initialize U-Net with ControlNet-LLLite")
-        with accelerate.init_empty_weights():
-            unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
-        unet_lllite.to(accelerator.device, dtype=weight_dtype)
-
-        unet_sd = unet.state_dict()
-        info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
-        accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
-    else:
-        # consumes large memory, so send to GPU before creating the LLLite model
-        accelerator.print("sending U-Net to GPU")
-        unet.to(accelerator.device, dtype=weight_dtype)
-        unet_sd = unet.state_dict()
-
-        # init LLLite weights
-        accelerator.print(f"initialize U-Net with ControlNet-LLLite")
-
-        if args.lowram:
-            with accelerate.init_on_device(accelerator.device):
-                unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
-        else:
-            unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
-        unet_lllite.to(weight_dtype)
-
-        info = unet_lllite.load_lllite_weights(None, unet_sd)
-        accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
-    del unet_sd, unet
-
-    unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
-    del unet_lllite
-
-    unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
-
-    # モデルに xformers とか memory efficient attention を組み込む / integrate xformers or memory efficient attention into the model
-    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+        info = network.load_weights(args.network_weights)
+        accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        network.enable_gradient_checkpointing()  # may have no effect

     # 学習に必要なクラスを準備する / prepare the classes required for training
     accelerator.print("prepare optimizer, data loader etc.")

-    trainable_params = list(unet.prepare_params())
+    trainable_params = list(network.prepare_optimizer_params())
     print(f"trainable params count: {len(trainable_params)}")
     print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
@@ -264,32 +232,37 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする / experimental feature: train in fp16/bf16 including gradients, casting the entire model to fp16/bf16
-    # if args.full_fp16:
-    #     assert (
-    #         args.mixed_precision == "fp16"
-    #     ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
-    #     accelerator.print("enable full fp16 training.")
-    #     unet.to(weight_dtype)
-    # elif args.full_bf16:
-    #     assert (
-    #         args.mixed_precision == "bf16"
-    #     ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
-    #     accelerator.print("enable full bf16 training.")
-    #     unet.to(weight_dtype)
-
-    unet.to(weight_dtype)
+    if args.full_fp16:
+        assert (
+            args.mixed_precision == "fp16"
+        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+        accelerator.print("enable full fp16 training.")
+        unet.to(weight_dtype)
+        network.to(weight_dtype)
+    elif args.full_bf16:
+        assert (
+            args.mixed_precision == "bf16"
+        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+        accelerator.print("enable full bf16 training.")
+        unet.to(weight_dtype)
+        network.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい / accelerator apparently takes care of this for us
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, network, optimizer, train_dataloader, lr_scheduler
+    )
+    network: control_net_lllite.ControlNetLLLite

     # transform DDP after prepare (train_network here only)
-    unet = train_util.transform_models_if_DDP([unet])[0]
+    unet, network = train_util.transform_models_if_DDP([unet, network])

     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> this is the original U-Net now, so this could actually be removed
     else:
         unet.eval()

+    network.prepare_grad_etc()
+
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する / move to CPU when caching the Text Encoder outputs
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -355,13 +328,7 @@ def train(args):
         del train_dataset_group

     # function for saving/removing
-    def save_model(
-        ckpt_name,
-        unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
-        steps,
-        epoch_no,
-        force_sync_upload=False,
-    ):
+    def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
         os.makedirs(args.output_dir, exist_ok=True)
         ckpt_file = os.path.join(args.output_dir, ckpt_name)

@@ -369,7 +336,7 @@ def train(args):
            sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
            sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"

-        unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
+        unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)

        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -384,9 +351,11 @@ def train(args):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1

+        network.on_epoch_start()  # train()
+
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(network):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device)
@@ -443,9 +412,10 @@ def train(args):
                 with accelerator.autocast():
                     # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
                     # 内部でcond_embに変換される / it will be converted to cond_emb inside
+                    network.set_cond_image(controlnet_image)

                     # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
-                    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
+                    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)

                 if args.v_parameterization:
                     # v-parameterization training
@@ -470,7 +440,7 @@ def train(args):

                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                    params_to_clip = unet.get_trainable_params()
+                    params_to_clip = network.get_trainable_params()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
@@ -489,7 +459,7 @@ def train(args):
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
                         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
-                        save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
+                        save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)

                         if args.save_state:
                             train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -528,7 +498,7 @@ def train(args):
            saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
            if is_main_process and saving:
                ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
-                save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
+                save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)

                remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                if remove_epoch_no is not None:
@@ -543,7 +513,7 @@ def train(args):

     # end of epoch
     if is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        network = accelerator.unwrap_model(network)

     accelerator.end_training()

@@ -552,7 +522,7 @@ def train(args):

     if is_main_process:
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
-        save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
+        save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)

         print("model saved.")