mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 16:22:28 +00:00
add comments
This commit is contained in:
@@ -5,12 +5,25 @@ from networks.lora import LoRAModule, LoRANetwork
|
||||
from library import sdxl_original_unet
|
||||
|
||||
|
||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||
SKIP_INPUT_BLOCKS = False
|
||||
|
||||
# output_blocksに適用するかどうか / if True, output_blocks are not applied
|
||||
SKIP_OUTPUT_BLOCKS = True
|
||||
|
||||
# conv2dに適用するかどうか / if True, conv2d are not applied
|
||||
SKIP_CONV2D = False
|
||||
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
|
||||
ATTN1_ETC_ONLY = False # True
|
||||
TRANSFORMER_MAX_BLOCK_INDEX = None # 3 # None # 2 # None for all blocks
|
||||
|
||||
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
|
||||
# if True, only transformer_blocks are applied, and ResBlocks are not applied
|
||||
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
|
||||
|
||||
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
||||
ATTN1_ETC_ONLY = False # True
|
||||
|
||||
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
|
||||
# max index of transformer_blocks. if None, apply to all transformer_blocks
|
||||
TRANSFORMER_MAX_BLOCK_INDEX = None
|
||||
|
||||
|
||||
class LoRAModuleControlNet(LoRAModule):
|
||||
@@ -19,6 +32,16 @@ class LoRAModuleControlNet(LoRAModule):
|
||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
self.cond_emb_dim = cond_emb_dim
|
||||
|
||||
# conditioning1は、conditioning image embeddingを、各LoRA的モジュールでさらに学習する。ここはtimestepごとに呼ばれない
|
||||
# それぞれのモジュールで異なる表現を学習することを期待している
|
||||
# conditioning1 learns conditioning image embedding in each LoRA-like module. this is not called for each timestep
|
||||
# we expect to learn different representations in each module
|
||||
|
||||
# conditioning2は、conditioning1の出力とLoRAの出力を結合し、LoRAの出力に加算する。timestepごとに呼ばれる
|
||||
# conditioning image embeddingとU-Netの出力を合わせて学ぶことで、conditioningに応じたU-Netの調整を行う
|
||||
# conditioning2 combines the output of conditioning1 and the output of LoRA, and adds it to the output of LoRA. this is called for each timestep
|
||||
# by learning the output of conditioning image embedding and U-Net together, we adjust U-Net according to conditioning
|
||||
|
||||
if self.is_conv2d:
|
||||
self.conditioning1 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
|
||||
@@ -45,16 +68,26 @@ class LoRAModuleControlNet(LoRAModule):
|
||||
torch.nn.Linear(cond_emb_dim, lora_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
# Zero-Convにするならコメントを外す / uncomment if you want to use Zero-Conv
|
||||
# torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv
|
||||
|
||||
self.depth = depth
|
||||
self.depth = depth # 1~3
|
||||
self.cond_emb = None
|
||||
self.batch_cond_only = False
|
||||
self.use_zeros_for_batch_uncond = False
|
||||
self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
|
||||
self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
|
||||
|
||||
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
|
||||
r"""
|
||||
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
# conv2dとlinearでshapeが違うので必要な方を選択 / select the required one because the shape is different for conv2d and linear
|
||||
cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d
|
||||
|
||||
cond_emb = cond_embs[self.depth - 1]
|
||||
|
||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||
self.cond_emb = self.conditioning1(cond_emb)
|
||||
|
||||
def set_batch_cond_only(self, cond_only, zeros):
|
||||
@@ -65,32 +98,39 @@ class LoRAModuleControlNet(LoRAModule):
|
||||
if self.cond_emb is None:
|
||||
return self.org_forward(x)
|
||||
|
||||
# LoRA
|
||||
# LoRA-Down
|
||||
lx = x
|
||||
if self.batch_cond_only:
|
||||
lx = lx[1::2] # cond only
|
||||
lx = lx[1::2] # cond only in inference
|
||||
|
||||
lx = self.lora_down(lx)
|
||||
|
||||
if self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# conditioning image
|
||||
# conditioning image embeddingを結合 / combine conditioning image embedding
|
||||
cx = self.cond_emb
|
||||
|
||||
if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only
|
||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||
if self.use_zeros_for_batch_uncond:
|
||||
cx[0::2] = 0.0 # uncond is zero
|
||||
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
|
||||
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
||||
# we expect that it will mix well by combining in the channel direction instead of adding
|
||||
cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.conditioning2(cx)
|
||||
|
||||
lx = lx + cx
|
||||
lx = lx + cx # lxはresidual的に加算される / lx is added residually
|
||||
|
||||
# LoRA-Up
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# call original module
|
||||
x = self.org_forward(x)
|
||||
|
||||
# add LoRA
|
||||
if self.batch_cond_only:
|
||||
x[1::2] += lx * self.multiplier * self.scale
|
||||
else:
|
||||
@@ -127,6 +167,7 @@ class LoRAControlNet(torch.nn.Module):
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
|
||||
if is_linear or (is_conv2d and not SKIP_CONV2D):
|
||||
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
|
||||
# block index to depth: depth is using to calculate conditioning size and channels
|
||||
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
||||
index1 = int(index1)
|
||||
@@ -155,7 +196,10 @@ class LoRAControlNet(torch.nn.Module):
|
||||
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
||||
continue
|
||||
|
||||
# skip time emb or clip emb
|
||||
# time embは適用外とする
|
||||
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
||||
# time emb is not applied
|
||||
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
||||
if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)):
|
||||
continue
|
||||
|
||||
@@ -191,8 +235,22 @@ class LoRAControlNet(torch.nn.Module):
|
||||
print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
# conditioning image embedding
|
||||
|
||||
# control画像そのままではLoRA的モジュールの入力にはサイズもチャネルも扱いにくいので、
|
||||
# 適切な潜在空間に変換する。ここでは、conditioning image embeddingと呼ぶ
|
||||
# ただcontrol画像自体にはあまり情報量はないので、conditioning image embeddingはわりと小さくてよいはず
|
||||
# また、conditioning image embeddingは、各LoRA的モジュールでさらに個別に学習する
|
||||
# depthに応じて3つのサイズを用意する
|
||||
|
||||
# conditioning image embedding is converted to an appropriate latent space
|
||||
# because the size and channels of the input to the LoRA-like module are difficult to handle
|
||||
# we call it conditioning image embedding
|
||||
# however, the control image itself does not have much information, so the conditioning image embedding should be small
|
||||
# conditioning image embedding is also learned individually in each LoRA-like module
|
||||
# prepare three sizes according to depth
|
||||
|
||||
self.cond_block0 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent size
|
||||
torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent (from VAE) size
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
@@ -216,7 +274,7 @@ class LoRAControlNet(torch.nn.Module):
|
||||
x = self.cond_block2(x)
|
||||
x2 = x
|
||||
|
||||
x_3d = []
|
||||
x_3d = [] # for Linear
|
||||
for x0 in [x0, x1, x2]:
|
||||
# b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = x0.shape
|
||||
@@ -226,6 +284,10 @@ class LoRAControlNet(torch.nn.Module):
|
||||
return [x0, x1, x2], x_3d
|
||||
|
||||
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
|
||||
r"""
|
||||
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
for lora in self.unet_loras:
|
||||
lora.set_cond_embs(cond_embs_4d, cond_embs_3d)
|
||||
|
||||
@@ -295,6 +357,9 @@ class LoRAControlNet(torch.nn.Module):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# デバッグ用 / for debug
|
||||
|
||||
# これを指定しないとエラーが出てcond_blockが学習できない / if not specified, an error occurs and cond_block cannot be learned
|
||||
sdxl_original_unet.USE_REENTRANT = False
|
||||
|
||||
# test shape etc
|
||||
@@ -303,7 +368,7 @@ if __name__ == "__main__":
|
||||
unet.to("cuda").to(torch.float16)
|
||||
|
||||
print("create LoRA controlnet")
|
||||
control_net = LoRAControlNet(unet, 128, 64, 1)
|
||||
control_net = LoRAControlNet(unet, 64, 16, 1)
|
||||
control_net.apply_to()
|
||||
control_net.to("cuda")
|
||||
|
||||
@@ -329,7 +394,7 @@ if __name__ == "__main__":
|
||||
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
||||
# print("render")
|
||||
# image.format = "svg" # "png"
|
||||
# image.render("NeuralNet")
|
||||
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
||||
# input()
|
||||
|
||||
import bitsandbytes
|
||||
|
||||
@@ -401,8 +401,13 @@ def train(args):
|
||||
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
# conditioning image embeddingを計算する / calculate conditioning image embedding
|
||||
cond_embs_4d, cond_embs_3d = network(controlnet_image)
|
||||
|
||||
# 個別のLoRA的モジュールでさらにembeddingを計算する / calculate embedding in each LoRA-like module
|
||||
network.set_cond_embs(cond_embs_4d, cond_embs_3d)
|
||||
|
||||
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
|
||||
if args.v_parameterization:
|
||||
@@ -514,262 +519,6 @@ def train(args):
|
||||
|
||||
print("model saved.")
|
||||
|
||||
r"""
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="steps",
|
||||
)
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, model, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
|
||||
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
|
||||
|
||||
if save_dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(ckpt_file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file(state_dict, ckpt_file)
|
||||
else:
|
||||
torch.save(state_dict, ckpt_file)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(controlnet):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(
|
||||
noise,
|
||||
latents.device,
|
||||
args.multires_noise_iterations,
|
||||
args.multires_noise_discount,
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.config.num_train_timesteps,
|
||||
(b_size,),
|
||||
device=latents.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
down_block_res_samples, mid_block_res_sample = controlnet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
controlnet_cond=controlnet_image,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
|
||||
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
||||
).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = controlnet.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(
|
||||
ckpt_name,
|
||||
accelerator.unwrap_model(controlnet),
|
||||
)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, accelerator.unwrap_model(controlnet))
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
# end of epoch
|
||||
if is_main_process:
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if is_main_process and args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, controlnet, force_sync_upload=True)
|
||||
|
||||
print("model saved.")
|
||||
"""
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user