mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
rename train_lllite_alt to train_lllite
This commit is contained in:
@@ -1,3 +1,6 @@
|
|||||||
|
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
|
||||||
|
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
@@ -20,6 +23,7 @@ except Exception:
|
|||||||
pass
|
pass
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
import accelerate
|
||||||
from diffusers import DDPMScheduler, ControlNetModel
|
from diffusers import DDPMScheduler, ControlNetModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
||||||
@@ -41,7 +45,7 @@ from library.custom_train_functions import (
|
|||||||
apply_noise_offset,
|
apply_noise_offset,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
)
|
)
|
||||||
import networks.control_net_lllite as control_net_lllite
|
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
@@ -148,9 +152,6 @@ def train(args):
|
|||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
@@ -184,22 +185,53 @@ def train(args):
|
|||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# prepare ControlNet
|
# prepare ControlNet-LLLite
|
||||||
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
|
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
|
||||||
network.apply_to()
|
|
||||||
|
|
||||||
if args.network_weights is not None:
|
if args.network_weights is not None:
|
||||||
info = network.load_weights(args.network_weights)
|
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
|
||||||
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
|
with accelerate.init_empty_weights():
|
||||||
|
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
||||||
|
unet_lllite.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
unet_sd = unet.state_dict()
|
||||||
|
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
|
||||||
|
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
|
||||||
|
else:
|
||||||
|
# cosumes large memory, so send to GPU before creating the LLLite model
|
||||||
|
accelerator.print("sending U-Net to GPU")
|
||||||
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
unet_sd = unet.state_dict()
|
||||||
|
|
||||||
|
# init LLLite weights
|
||||||
|
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
|
||||||
|
|
||||||
|
if args.lowram:
|
||||||
|
with accelerate.init_on_device(accelerator.device):
|
||||||
|
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
||||||
|
else:
|
||||||
|
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
||||||
|
unet_lllite.to(weight_dtype)
|
||||||
|
|
||||||
|
info = unet_lllite.load_lllite_weights(None, unet_sd)
|
||||||
|
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
|
||||||
|
del unet_sd, unet
|
||||||
|
|
||||||
|
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
|
||||||
|
del unet_lllite
|
||||||
|
|
||||||
|
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
|
||||||
|
|
||||||
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
network.enable_gradient_checkpointing() # may have no effect
|
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
accelerator.print("prepare optimizer, data loader etc.")
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
trainable_params = list(network.prepare_optimizer_params())
|
trainable_params = list(unet.prepare_params())
|
||||||
print(f"trainable params count: {len(trainable_params)}")
|
print(f"trainable params count: {len(trainable_params)}")
|
||||||
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||||
|
|
||||||
@@ -232,37 +264,32 @@ def train(args):
|
|||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||||
if args.full_fp16:
|
# if args.full_fp16:
|
||||||
assert (
|
# assert (
|
||||||
args.mixed_precision == "fp16"
|
# args.mixed_precision == "fp16"
|
||||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||||
accelerator.print("enable full fp16 training.")
|
# accelerator.print("enable full fp16 training.")
|
||||||
unet.to(weight_dtype)
|
# unet.to(weight_dtype)
|
||||||
network.to(weight_dtype)
|
# elif args.full_bf16:
|
||||||
elif args.full_bf16:
|
# assert (
|
||||||
assert (
|
# args.mixed_precision == "bf16"
|
||||||
args.mixed_precision == "bf16"
|
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
# accelerator.print("enable full bf16 training.")
|
||||||
accelerator.print("enable full bf16 training.")
|
# unet.to(weight_dtype)
|
||||||
unet.to(weight_dtype)
|
|
||||||
network.to(weight_dtype)
|
unet.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
unet, network, optimizer, train_dataloader, lr_scheduler
|
|
||||||
)
|
|
||||||
network: control_net_lllite.ControlNetLLLite
|
|
||||||
|
|
||||||
# transform DDP after prepare (train_network here only)
|
# transform DDP after prepare (train_network here only)
|
||||||
unet, network = train_util.transform_models_if_DDP([unet, network])
|
unet = train_util.transform_models_if_DDP([unet])[0]
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||||
else:
|
else:
|
||||||
unet.eval()
|
unet.eval()
|
||||||
|
|
||||||
network.prepare_grad_etc()
|
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
@@ -328,7 +355,13 @@ def train(args):
|
|||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
def save_model(
|
||||||
|
ckpt_name,
|
||||||
|
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
|
||||||
|
steps,
|
||||||
|
epoch_no,
|
||||||
|
force_sync_upload=False,
|
||||||
|
):
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
@@ -336,7 +369,7 @@ def train(args):
|
|||||||
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
|
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
|
||||||
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
|
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
|
||||||
|
|
||||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
|
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||||
|
|
||||||
@@ -351,11 +384,9 @@ def train(args):
|
|||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
network.on_epoch_start() # train()
|
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(unet):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
@@ -412,10 +443,9 @@ def train(args):
|
|||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
|
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
|
||||||
# 内部でcond_embに変換される / it will be converted to cond_emb inside
|
# 内部でcond_embに変換される / it will be converted to cond_emb inside
|
||||||
network.set_cond_image(controlnet_image)
|
|
||||||
|
|
||||||
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
||||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
# v-parameterization training
|
# v-parameterization training
|
||||||
@@ -440,7 +470,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
params_to_clip = network.get_trainable_params()
|
params_to_clip = unet.get_trainable_params()
|
||||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
@@ -459,7 +489,7 @@ def train(args):
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||||
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
|
||||||
|
|
||||||
if args.save_state:
|
if args.save_state:
|
||||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||||
@@ -498,7 +528,7 @@ def train(args):
|
|||||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||||
if is_main_process and saving:
|
if is_main_process and saving:
|
||||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||||
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
|
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
|
||||||
|
|
||||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||||
if remove_epoch_no is not None:
|
if remove_epoch_no is not None:
|
||||||
@@ -513,7 +543,7 @@ def train(args):
|
|||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
network = accelerator.unwrap_model(network)
|
unet = accelerator.unwrap_model(unet)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
@@ -522,7 +552,7 @@ def train(args):
|
|||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||||
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
|
||||||
|
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
|
|
||||||
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
@@ -23,7 +20,6 @@ except Exception:
|
|||||||
pass
|
pass
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import accelerate
|
|
||||||
from diffusers import DDPMScheduler, ControlNetModel
|
from diffusers import DDPMScheduler, ControlNetModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
||||||
@@ -45,7 +41,7 @@ from library.custom_train_functions import (
|
|||||||
apply_noise_offset,
|
apply_noise_offset,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
)
|
)
|
||||||
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
import networks.control_net_lllite as control_net_lllite
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
@@ -152,6 +148,9 @@ def train(args):
|
|||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||||
|
|
||||||
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
@@ -185,53 +184,22 @@ def train(args):
|
|||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# prepare ControlNet-LLLite
|
# prepare ControlNet
|
||||||
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
|
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
|
||||||
|
network.apply_to()
|
||||||
|
|
||||||
if args.network_weights is not None:
|
if args.network_weights is not None:
|
||||||
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
|
info = network.load_weights(args.network_weights)
|
||||||
with accelerate.init_empty_weights():
|
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
|
||||||
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
|
||||||
unet_lllite.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
unet_sd = unet.state_dict()
|
|
||||||
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
|
|
||||||
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
|
|
||||||
else:
|
|
||||||
# cosumes large memory, so send to GPU before creating the LLLite model
|
|
||||||
accelerator.print("sending U-Net to GPU")
|
|
||||||
unet.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
unet_sd = unet.state_dict()
|
|
||||||
|
|
||||||
# init LLLite weights
|
|
||||||
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
|
|
||||||
|
|
||||||
if args.lowram:
|
|
||||||
with accelerate.init_on_device(accelerator.device):
|
|
||||||
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
|
||||||
else:
|
|
||||||
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
|
||||||
unet_lllite.to(weight_dtype)
|
|
||||||
|
|
||||||
info = unet_lllite.load_lllite_weights(None, unet_sd)
|
|
||||||
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
|
|
||||||
del unet_sd, unet
|
|
||||||
|
|
||||||
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
|
|
||||||
del unet_lllite
|
|
||||||
|
|
||||||
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
|
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
network.enable_gradient_checkpointing() # may have no effect
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
accelerator.print("prepare optimizer, data loader etc.")
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
trainable_params = list(unet.prepare_params())
|
trainable_params = list(network.prepare_optimizer_params())
|
||||||
print(f"trainable params count: {len(trainable_params)}")
|
print(f"trainable params count: {len(trainable_params)}")
|
||||||
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||||
|
|
||||||
@@ -264,32 +232,37 @@ def train(args):
|
|||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||||
# if args.full_fp16:
|
if args.full_fp16:
|
||||||
# assert (
|
assert (
|
||||||
# args.mixed_precision == "fp16"
|
args.mixed_precision == "fp16"
|
||||||
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||||
# accelerator.print("enable full fp16 training.")
|
accelerator.print("enable full fp16 training.")
|
||||||
# unet.to(weight_dtype)
|
unet.to(weight_dtype)
|
||||||
# elif args.full_bf16:
|
network.to(weight_dtype)
|
||||||
# assert (
|
elif args.full_bf16:
|
||||||
# args.mixed_precision == "bf16"
|
assert (
|
||||||
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
args.mixed_precision == "bf16"
|
||||||
# accelerator.print("enable full bf16 training.")
|
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||||
# unet.to(weight_dtype)
|
accelerator.print("enable full bf16 training.")
|
||||||
|
unet.to(weight_dtype)
|
||||||
unet.to(weight_dtype)
|
network.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
unet, network, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
network: control_net_lllite.ControlNetLLLite
|
||||||
|
|
||||||
# transform DDP after prepare (train_network here only)
|
# transform DDP after prepare (train_network here only)
|
||||||
unet = train_util.transform_models_if_DDP([unet])[0]
|
unet, network = train_util.transform_models_if_DDP([unet, network])
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||||
else:
|
else:
|
||||||
unet.eval()
|
unet.eval()
|
||||||
|
|
||||||
|
network.prepare_grad_etc()
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
@@ -355,13 +328,7 @@ def train(args):
|
|||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
def save_model(
|
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
||||||
ckpt_name,
|
|
||||||
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
|
|
||||||
steps,
|
|
||||||
epoch_no,
|
|
||||||
force_sync_upload=False,
|
|
||||||
):
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
@@ -369,7 +336,7 @@ def train(args):
|
|||||||
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
|
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
|
||||||
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
|
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
|
||||||
|
|
||||||
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
|
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||||
|
|
||||||
@@ -384,9 +351,11 @@ def train(args):
|
|||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
|
network.on_epoch_start() # train()
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(unet):
|
with accelerator.accumulate(network):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
@@ -443,9 +412,10 @@ def train(args):
|
|||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
|
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
|
||||||
# 内部でcond_embに変換される / it will be converted to cond_emb inside
|
# 内部でcond_embに変換される / it will be converted to cond_emb inside
|
||||||
|
network.set_cond_image(controlnet_image)
|
||||||
|
|
||||||
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
||||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
|
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
# v-parameterization training
|
# v-parameterization training
|
||||||
@@ -470,7 +440,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
params_to_clip = unet.get_trainable_params()
|
params_to_clip = network.get_trainable_params()
|
||||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
@@ -489,7 +459,7 @@ def train(args):
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||||
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
|
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
||||||
|
|
||||||
if args.save_state:
|
if args.save_state:
|
||||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||||
@@ -528,7 +498,7 @@ def train(args):
|
|||||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||||
if is_main_process and saving:
|
if is_main_process and saving:
|
||||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||||
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
|
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
|
||||||
|
|
||||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||||
if remove_epoch_no is not None:
|
if remove_epoch_no is not None:
|
||||||
@@ -543,7 +513,7 @@ def train(args):
|
|||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
unet = accelerator.unwrap_model(unet)
|
network = accelerator.unwrap_model(network)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
@@ -552,7 +522,7 @@ def train(args):
|
|||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||||
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
|
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
||||||
|
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
Reference in New Issue
Block a user