add comments

Kohya S
2023-08-17 20:49:39 +09:00
parent 809fca0be9
commit 1e52fe6e09
2 changed files with 85 additions and 271 deletions


@@ -5,12 +5,25 @@ from networks.lora import LoRAModule, LoRANetwork
from library import sdxl_original_unet
# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
# output_blocksに適用するかどうか / if True, output_blocks are not applied
SKIP_OUTPUT_BLOCKS = True
# conv2dに適用するかどうか / if True, conv2d is not applied
SKIP_CONV2D = False
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
ATTN1_ETC_ONLY = False # True
TRANSFORMER_MAX_BLOCK_INDEX = None # 3 # None # 2 # None for all blocks
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
# if True, only transformer_blocks are applied, and ResBlocks are not applied
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
ATTN1_ETC_ONLY = False # True
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
# max index of transformer_blocks. if None, apply to all transformer_blocks
TRANSFORMER_MAX_BLOCK_INDEX = None
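# 参考 / illustrative summary of how these flags are expected to gate module selection
# (a sketch; only the SKIP_CONV2D and TRANSFORMER_MAX_BLOCK_INDEX checks are visible in this diff):
#   SKIP_INPUT_BLOCKS / SKIP_OUTPUT_BLOCKS -> skip modules whose name contains input_blocks / output_blocks
#   TRANSFORMER_ONLY -> only modules inside transformer_blocks (conv2d never matches there)
#   ATTN1_ETC_ONLY -> restrict further to attn1, ff, etc.
#   TRANSFORMER_MAX_BLOCK_INDEX -> skip transformer_blocks with a larger index (checked further below)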
class LoRAModuleControlNet(LoRAModule):
@@ -19,6 +32,16 @@ class LoRAModuleControlNet(LoRAModule):
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
self.cond_emb_dim = cond_emb_dim
# conditioning1は、conditioning image embeddingを、各LoRA的モジュールでさらに学習する。ここはtimestepごとに呼ばれない
# それぞれのモジュールで異なる表現を学習することを期待している
# conditioning1 further learns the conditioning image embedding in each LoRA-like module. this is not called for each timestep
# we expect each module to learn a different representation
# conditioning2は、conditioning1の出力とLoRAの出力を結合し、LoRAの出力に加算する。timestepごとに呼ばれる
# conditioning image embeddingとU-Netの出力を合わせて学ぶことで、conditioningに応じたU-Netの調整を行う
# conditioning2 concatenates the output of conditioning1 with the LoRA-Down output and adds the result back to it. this is called for each timestep
# by learning from the conditioning image embedding and the U-Net activations together, the U-Net is adjusted according to the conditioning
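# 参考 / a sketch of the resulting data flow (illustrative; it matches forward below):
#   cx = conditioning1(cond_emb)              # once per conditioning image (set_cond_embs)
#   lx = lora_down(x)                         # every timestep
#   lx = lx + conditioning2(cat([cx, lx]))    # every timestep, residual add
#   out = org_forward(x) + lora_up(lx) * multiplier * scale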
if self.is_conv2d:
self.conditioning1 = torch.nn.Sequential(
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
@@ -45,16 +68,26 @@ class LoRAModuleControlNet(LoRAModule):
torch.nn.Linear(cond_emb_dim, lora_dim),
torch.nn.ReLU(inplace=True),
)
# Zero-Convにするならコメントを外す / uncomment if you want to use Zero-Conv
# torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv
self.depth = depth
self.depth = depth # 1~3
self.cond_emb = None
self.batch_cond_only = False
self.use_zeros_for_batch_uncond = False
self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
r"""
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
/ this calls the model internally, so wrap the call in torch.no_grad() if necessary
"""
# conv2dとlinearでshapeが違うので必要な方を選択 / the shape differs between conv2d and linear, so select the appropriate one
cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d
cond_emb = cond_embs[self.depth - 1]
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
self.cond_emb = self.conditioning1(cond_emb)
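# 参考 / shape note (a sketch, inferred from the forward of LoRAControlNet below):
#   cond_embs_4d[d] is (B, cond_emb_dim, H_d, W_d) for conv2d modules,
#   cond_embs_3d[d] is (B, H_d*W_d, cond_emb_dim) for linear modules, with d = depth - 1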
def set_batch_cond_only(self, cond_only, zeros):
@@ -65,32 +98,39 @@ class LoRAModuleControlNet(LoRAModule):
if self.cond_emb is None:
return self.org_forward(x)
# LoRA
# LoRA-Down
lx = x
if self.batch_cond_only:
lx = lx[1::2] # cond only
lx = lx[1::2] # cond only in inference
lx = self.lora_down(lx)
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# conditioning image
# conditioning image embeddingを結合 / combine conditioning image embedding
cx = self.cond_emb
if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
if self.use_zeros_for_batch_uncond:
cx[0::2] = 0.0 # uncond is zero
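# 参考 / batch layout example (classifier-free guidance), implied by the [1::2] / [0::2] indexing:
#   x = [uncond_0, cond_0, uncond_1, cond_1, ...]; x[1::2] selects cond, x[0::2] selects uncond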
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
# by concatenating along the channel dimension instead of adding, we expect the two signals to mix well
cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
cx = self.conditioning2(cx)
lx = lx + cx
lx = lx + cx # lxはresidual的に加算される / lx is added residually
# LoRA-Up
lx = self.lora_up(lx)
# call original module
x = self.org_forward(x)
# add LoRA
if self.batch_cond_only:
x[1::2] += lx * self.multiplier * self.scale
else:
@@ -127,6 +167,7 @@ class LoRAControlNet(torch.nn.Module):
is_conv2d = child_module.__class__.__name__ == "Conv2d"
if is_linear or (is_conv2d and not SKIP_CONV2D):
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
# block index to depth: depth is used to calculate the conditioning size and channels
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
index1 = int(index1)
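# 例 / an illustrative example of the parsing above (the module path is hypothetical
# but follows sdxl_original_unet naming):
#   "input_blocks.4.1.transformer_blocks.0.attn1.to_q" -> block_name="input_blocks", index1=4, index2=1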
@@ -155,7 +196,10 @@ class LoRAControlNet(torch.nn.Module):
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
continue
# skip time emb or clip emb
# time embは適用外とする
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
# time emb is not applied
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)):
continue
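# 例 / illustrative names caught by the skip above:
#   "...emb_layers.1" (time embedding), "...attn2.to_k" / "...attn2.to_v" (CLIP cross-attention)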
@@ -191,8 +235,22 @@ class LoRAControlNet(torch.nn.Module):
print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.")
# conditioning image embedding
# control画像そのままではLoRA的モジュールの入力にはサイズもチャネルも扱いにくいので、
# 適切な潜在空間に変換する。ここでは、conditioning image embeddingと呼ぶ
# ただcontrol画像自体にはあまり情報量はないので、conditioning image embeddingはわりと小さくてよいはず
# また、conditioning image embeddingは、各LoRA的モジュールでさらに個別に学習する
# depthに応じて3つのサイズを用意する
# the raw control image is hard to use directly as input to the LoRA-like modules (neither its size nor its channels match),
# so it is converted to an appropriate latent space. we call this the conditioning image embedding
# however, the control image itself does not carry much information, so the conditioning image embedding can be fairly small
# conditioning image embedding is also learned individually in each LoRA-like module
# prepare three sizes according to depth
self.cond_block0 = torch.nn.Sequential(
torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent size
torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent (from VAE) size
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1),
torch.nn.ReLU(inplace=True),
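# 形状の例 / shape walk-through for cond_block0 (assuming a 1024x1024 control image):
#   (B, 3, 1024, 1024) --k4,s4--> (B, cond_emb_dim//2, 256, 256) --k3,s2,p1--> (B, cond_emb_dim, 128, 128)
#   128x128 matches the SDXL latent resolution of a 1024px image (the VAE downsamples by 8)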
@@ -216,7 +274,7 @@ class LoRAControlNet(torch.nn.Module):
x = self.cond_block2(x)
x2 = x
x_3d = []
x_3d = [] # for Linear
for x0 in [x0, x1, x2]:
# b,c,h,w -> b,h*w,c
n, c, h, w = x0.shape
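# 参考 / a sketch of the reshape this comment describes (the actual lines are elided from this hunk):
#   x0 = x0.view(n, c, h * w).permute(0, 2, 1)  # (B, C, H, W) -> (B, H*W, C)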
@@ -226,6 +284,10 @@ class LoRAControlNet(torch.nn.Module):
return [x0, x1, x2], x_3d
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
r"""
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
/ this calls the model internally, so wrap the call in torch.no_grad() if necessary
"""
for lora in self.unet_loras:
lora.set_cond_embs(cond_embs_4d, cond_embs_3d)
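# 使用例 / usage sketch (illustrative; assumes the forward above returns the two lists):
#   with torch.no_grad():
#       cond_embs_4d, cond_embs_3d = control_net(cond_image)  # cond_image: (B, 3, H, W)
#       control_net.set_cond_embs(cond_embs_4d, cond_embs_3d)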
@@ -295,6 +357,9 @@ class LoRAControlNet(torch.nn.Module):
if __name__ == "__main__":
# デバッグ用 / for debug
# これを指定しないとエラーが出てcond_blockが学習できない / without this, an error occurs and cond_block cannot be trained
sdxl_original_unet.USE_REENTRANT = False
# test shape etc
@@ -303,7 +368,7 @@ if __name__ == "__main__":
unet.to("cuda").to(torch.float16)
print("create LoRA controlnet")
control_net = LoRAControlNet(unet, 128, 64, 1)
control_net = LoRAControlNet(unet, 64, 16, 1)
control_net.apply_to()
control_net.to("cuda")
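# 参考 / an illustrative continuation of this debug flow (a sketch; the real test code
# below this hunk is elided):
#   cond_img = torch.rand(1, 3, 1024, 1024, device="cuda", dtype=torch.float16)
#   with torch.no_grad():
#       control_net.set_cond_embs(*control_net(cond_img))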
@@ -329,7 +394,7 @@ if __name__ == "__main__":
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# image.format = "svg" # "png"
# image.render("NeuralNet")
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
import bitsandbytes