rename and update

ykume
2023-08-19 18:44:40 +09:00
parent 62fa4734fe
commit fef7eb73ad
4 changed files with 253 additions and 2824 deletions

View File

@@ -1,7 +1,6 @@
 import os
 from typing import Optional, List, Type
 import torch
-from networks.lora import LoRAModule, LoRANetwork
 from library import sdxl_original_unet
@@ -21,6 +20,9 @@ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not
 # Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
 ATTN1_2_ONLY = True
+# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
+ATTN_QKV_ONLY = True
 # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
 # ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
 ATTN1_ETC_ONLY = False  # True
@@ -30,126 +32,159 @@ ATTN1_ETC_ONLY = False # True
 TRANSFORMER_MAX_BLOCK_INDEX = None


-class LoRAModuleControlNet(LoRAModule):
-    def __init__(self, depth, cond_emb_dim, name, org_module, multiplier, lora_dim, alpha, dropout=None):
-        super().__init__(name, org_module, multiplier, lora_dim, alpha, dropout=dropout)
+class LLLiteModule(torch.nn.Module):
+    def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None):
+        super().__init__()
         self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
+        self.lllite_name = name
         self.cond_emb_dim = cond_emb_dim
+        self.org_module = [org_module]
+        self.dropout = dropout

-        # conditioning1は、conditioning image embeddingを、各LoRA的モジュールでさらに学習する。ここはtimestepごとに呼ばれない
-        # それぞれのモジュールで異なる表現を学習することを期待している
-        # conditioning1 learns conditioning image embedding in each LoRA-like module. this is not called for each timestep
-        # we expect to learn different representations in each module
-        # conditioning2は、conditioning1の出力とLoRAの出力を結合し、LoRAの出力に加算する。timestepごとに呼ばれる
-        # conditioning image embeddingとU-Netの出力を合わせて学ぶことで、conditioningに応じたU-Netの調整を行う
-        # conditioning2 combines the output of conditioning1 and the output of LoRA, and adds it to the output of LoRA. this is called for each timestep
-        # by learning the output of conditioning image embedding and U-Net together, we adjust U-Net according to conditioning
         if self.is_conv2d:
-            self.conditioning1 = torch.nn.Sequential(
-                torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
-                torch.nn.ReLU(inplace=True),
-                torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
-                torch.nn.ReLU(inplace=True),
-            )
-            self.conditioning2 = torch.nn.Sequential(
-                torch.nn.Conv2d(lora_dim + cond_emb_dim, cond_emb_dim, kernel_size=1, stride=1, padding=0),
-                torch.nn.ReLU(inplace=True),
-                torch.nn.Conv2d(cond_emb_dim, lora_dim, kernel_size=1, stride=1, padding=0),
-                torch.nn.ReLU(inplace=True),
-            )
+            in_dim = org_module.in_channels
         else:
-            self.conditioning1 = torch.nn.Sequential(
-                torch.nn.Linear(cond_emb_dim, cond_emb_dim),
-                torch.nn.ReLU(inplace=True),
-                torch.nn.Linear(cond_emb_dim, cond_emb_dim),
-                torch.nn.ReLU(inplace=True),
-            )
-            self.conditioning2 = torch.nn.Sequential(
-                torch.nn.Linear(lora_dim + cond_emb_dim, cond_emb_dim),
-                torch.nn.ReLU(inplace=True),
-                torch.nn.Linear(cond_emb_dim, lora_dim),
-                torch.nn.ReLU(inplace=True),
-            )
+            in_dim = org_module.in_features
+
+        # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
+        # conditioning1 embeds conditioning image. it is not called for each timestep
+        modules = []
+        modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))  # to latent (from VAE) size
+        if depth == 1:
+            modules.append(torch.nn.ReLU(inplace=True))
+            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
+        elif depth == 2:
+            modules.append(torch.nn.ReLU(inplace=True))
+            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
+        elif depth == 3:
+            # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
+            modules.append(torch.nn.ReLU(inplace=True))
+            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
+            modules.append(torch.nn.ReLU(inplace=True))
+            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
+
+        self.conditioning1 = torch.nn.Sequential(*modules)
+
+        # downで入力の次元数を削減する。LoRAにヒントを得ていることにする
+        # midでconditioning image embeddingと入力を結合する
+        # upで元の次元数に戻す
+        # これらはtimestepごとに呼ばれる
+        # reduce the number of input dimensions with down. inspired by LoRA
+        # combine conditioning image embedding and input with mid
+        # restore to the original dimension with up
+        # these are called for each timestep
+        if self.is_conv2d:
+            self.down = torch.nn.Sequential(
+                torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
+                torch.nn.ReLU(inplace=True),
+            )
+            self.mid = torch.nn.Sequential(
+                torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
+                torch.nn.ReLU(inplace=True),
+            )
+            self.up = torch.nn.Sequential(
+                torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
+            )
+        else:
+            # midの前にconditioningをreshapeすること / reshape conditioning before mid
+            self.down = torch.nn.Sequential(
+                torch.nn.Linear(in_dim, mlp_dim),
+                torch.nn.ReLU(inplace=True),
+            )
+            self.mid = torch.nn.Sequential(
+                torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
+                torch.nn.ReLU(inplace=True),
+            )
+            self.up = torch.nn.Sequential(
+                torch.nn.Linear(mlp_dim, in_dim),
+            )
-        # Zero-Convにするならコメントを外す / uncomment if you want to use Zero-Conv
-        # torch.nn.init.zeros_(self.conditioning2[-2].weight)  # zero conv
+        # Zero-Convにする / set to Zero-Conv
+        torch.nn.init.zeros_(self.up[0].weight)  # zero conv

         self.depth = depth  # 1~3
         self.cond_emb = None
         self.batch_cond_only = False  # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
         self.use_zeros_for_batch_uncond = False  # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0

+        # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
+        # Controlの種類によっては使えるかも
+        # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
+        # it may be available depending on the type of Control

-    def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
+    def set_cond_image(self, cond_image):
         r"""
         中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
         / call the model inside, so if necessary, surround it with torch.no_grad()
         """
-        # conv2dとlinearでshapeが違うので必要な方を選択 / select the required one because the shape is different for conv2d and linear
-        cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d
-        cond_emb = cond_embs[self.depth - 1]
         # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
-        self.cond_emb = self.conditioning1(cond_emb)
+        cx = self.conditioning1(cond_image)
+        if not self.is_conv2d:
+            # reshape / b,c,h,w -> b,h*w,c
+            n, c, h, w = cx.shape
+            cx = cx.view(n, c, h * w).permute(0, 2, 1)
+        self.cond_emb = cx

     def set_batch_cond_only(self, cond_only, zeros):
         self.batch_cond_only = cond_only
         self.use_zeros_for_batch_uncond = zeros

+    def apply_to(self):
+        self.org_forward = self.org_module[0].forward
+        self.org_module[0].forward = self.forward
     def forward(self, x):
-        if self.cond_emb is None:
-            return self.org_forward(x)
-
-        # LoRA-Down
-        lx = x
-        if self.batch_cond_only:
-            lx = lx[1::2]  # cond only in inference
-        lx = self.lora_down(lx)
-        if self.dropout is not None and self.training:
-            lx = torch.nn.functional.dropout(lx, p=self.dropout)
-
-        # conditioning image embeddingを結合 / combine conditioning image embedding
+        r"""
+        学習用の便利forward。元のモジュールのforwardを呼び出す
+        / convenient forward for training. call the forward of the original module
+        """
         cx = self.cond_emb

-        if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]:  # inference only
+        if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]:  # inference only
             cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
             if self.use_zeros_for_batch_uncond:
                 cx[0::2] = 0.0  # uncond is zero
-        # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
+        # print(f"C {self.lllite_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")

+        # downで入力の次元数を削減し、conditioning image embeddingと結合する
         # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
+        # down reduces the number of input dimensions and combines it with conditioning image embedding
         # we expect that it will mix well by combining in the channel direction instead of adding
-        cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
-        cx = self.conditioning2(cx)
-        lx = lx + cx  # lxはresidual的に加算される / lx is added residually
+        cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
+        cx = self.mid(cx)

-        # LoRA-Up
-        lx = self.lora_up(lx)
+        if self.dropout is not None and self.training:
+            cx = torch.nn.functional.dropout(cx, p=self.dropout)

-        # call original module
-        x = self.org_forward(x)
+        cx = self.up(cx)

-        # add LoRA
+        # residualを加算する / add residual
         if self.batch_cond_only:
-            x[1::2] += lx * self.multiplier * self.scale
+            x[1::2] += cx
         else:
-            x += lx * self.multiplier * self.scale
+            # to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone
+            # RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace.
+            # This view was created inside a custom Function ...
+            # x = x.clone()
+            x += cx

+        x = self.org_forward(x)  # ここで元のモジュールを呼び出す / call the original module here
         return x
-class LoRAControlNet(torch.nn.Module):
+class ControlNetLLLite(torch.nn.Module):
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+
     def __init__(
         self,
         unet: sdxl_original_unet.SdxlUNet2DConditionModel,
         cond_emb_dim: int = 16,
-        lora_dim: int = 16,
-        alpha: float = 1,
+        mlp_dim: int = 16,
         dropout: Optional[float] = None,
         varbose: Optional[bool] = False,
     ) -> None:
@@ -161,9 +196,9 @@ class LoRAControlNet(torch.nn.Module):
             target_replace_modules: List[torch.nn.Module],
             module_class: Type[object],
         ) -> List[torch.nn.Module]:
-            prefix = LoRANetwork.LORA_PREFIX_UNET
-            loras = []
+            prefix = "lllite_unet"
+            modules = []
             for name, module in root_module.named_modules():
                 if module.__class__.__name__ in target_replace_modules:
                     for child_name, child_module in module.named_modules():
@@ -190,13 +225,13 @@ class LoRAControlNet(torch.nn.Module):
                         else:
                             raise NotImplementedError()
-                        lora_name = prefix + "." + name + "." + child_name
-                        lora_name = lora_name.replace(".", "_")
+                        lllite_name = prefix + "." + name + "." + child_name
+                        lllite_name = lllite_name.replace(".", "_")

                         if TRANSFORMER_MAX_BLOCK_INDEX is not None:
-                            p = lora_name.find("transformer_blocks")
+                            p = lllite_name.find("transformer_blocks")
                             if p >= 0:
-                                tf_index = int(lora_name[p:].split("_")[2])
+                                tf_index = int(lllite_name[p:].split("_")[2])
                                 if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
                                     continue
@@ -204,104 +239,63 @@ class LoRAControlNet(torch.nn.Module):
                         # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
                         # time emb is not applied
                         # attn2 conditioning (input from CLIP) cannot be applied because the shape is different
-                        if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)):
+                        if "emb_layers" in lllite_name or (
+                            "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
+                        ):
                             continue

                         if ATTN1_2_ONLY:
-                            if not ("attn1" in lora_name or "attn2" in lora_name):
+                            if not ("attn1" in lllite_name or "attn2" in lllite_name):
                                 continue
+                            if ATTN_QKV_ONLY:
+                                if "to_out" in lllite_name:
+                                    continue

                         if ATTN1_ETC_ONLY:
-                            if "proj_out" in lora_name:
+                            if "proj_out" in lllite_name:
                                 pass
-                            elif "attn1" in lora_name and ("to_k" in lora_name or "to_v" in lora_name or "to_out" in lora_name):
+                            elif "attn1" in lllite_name and (
+                                "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
+                            ):
                                 pass
-                            elif "ff_net_2" in lora_name:
+                            elif "ff_net_2" in lllite_name:
                                 pass
                             else:
                                 continue

-                        lora = module_class(
+                        module = module_class(
                             depth,
                             cond_emb_dim,
-                            lora_name,
+                            lllite_name,
                             child_module,
-                            1.0,
-                            lora_dim,
-                            alpha,
+                            mlp_dim,
                             dropout=dropout,
                         )
-                        loras.append(lora)
-            return loras
+                        modules.append(module)
+            return modules

-        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
+        target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
         if not TRANSFORMER_ONLY:
-            target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+            target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

         # create module instances
-        self.unet_loras: List[LoRAModuleControlNet] = create_modules(unet, target_modules, LoRAModuleControlNet)
-        print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.")
+        self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
+        print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
-        # conditioning image embedding
-        # control画像そのままではLoRA的モジュールの入力にはサイズもチャネルも扱いにくいので、
-        # 適切な潜在空間に変換する。ここでは、conditioning image embeddingと呼ぶ
-        # ただcontrol画像自体にはあまり情報量はないので、conditioning image embeddingはわりと小さくてよいはず
-        # また、conditioning image embeddingは、各LoRA的モジュールでさらに個別に学習する
-        # depthに応じて3つのサイズを用意する
-        # conditioning image embedding is converted to an appropriate latent space
-        # because the size and channels of the input to the LoRA-like module are difficult to handle
-        # we call it conditioning image embedding
-        # however, the control image itself does not have much information, so the conditioning image embedding should be small
-        # conditioning image embedding is also learned individually in each LoRA-like module
-        # prepare three sizes according to depth
-        self.cond_block0 = torch.nn.Sequential(
-            torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0),  # to latent (from VAE) size
-            torch.nn.ReLU(inplace=True),
-            torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1),
-            torch.nn.ReLU(inplace=True),
-        )
-        self.cond_block1 = torch.nn.Sequential(
-            torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=1),
-            torch.nn.ReLU(inplace=True),
-            torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
-            torch.nn.ReLU(inplace=True),
-        )
-        self.cond_block2 = torch.nn.Sequential(
-            torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
-            torch.nn.ReLU(inplace=True),
-        )
-
     def forward(self, x):
-        x = self.cond_block0(x)
-        x0 = x
-        x = self.cond_block1(x)
-        x1 = x
-        x = self.cond_block2(x)
-        x2 = x
-
-        x_3d = []  # for Linear
-        for x0 in [x0, x1, x2]:
-            # b,c,h,w -> b,h*w,c
-            n, c, h, w = x0.shape
-            x0 = x0.view(n, c, h * w).permute(0, 2, 1)
-            x_3d.append(x0)
-
-        return [x0, x1, x2], x_3d
+        return x  # dummy

-    def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
+    def set_cond_image(self, cond_image):
         r"""
         中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
         / call the model inside, so if necessary, surround it with torch.no_grad()
         """
-        for lora in self.unet_loras:
-            lora.set_cond_embs(cond_embs_4d, cond_embs_3d)
+        for module in self.unet_modules:
+            module.set_cond_image(cond_image)

     def set_batch_cond_only(self, cond_only, zeros):
-        for lora in self.unet_loras:
-            lora.set_batch_cond_only(cond_only, zeros)
+        for module in self.unet_modules:
+            module.set_batch_cond_only(cond_only, zeros)

     def load_weights(self, file):
         if os.path.splitext(file)[1] == ".safetensors":
@@ -315,10 +309,10 @@ class LoRAControlNet(torch.nn.Module):
         return info

     def apply_to(self):
-        print("applying LoRA for U-Net...")
-        for lora in self.unet_loras:
-            lora.apply_to()
-            self.add_module(lora.lora_name, lora)
+        print("applying LLLite for U-Net...")
+        for module in self.unet_modules:
+            module.apply_to()
+            self.add_module(module.lllite_name, module)

     # マージできるかどうかを返す
     def is_mergeable(self):
@@ -367,16 +361,15 @@ class LoRAControlNet(torch.nn.Module):
 if __name__ == "__main__":
     # デバッグ用 / for debug

-    # これを指定しないとエラーが出てcond_blockが学習できない / if not specified, an error occurs and cond_block cannot be learned
-    sdxl_original_unet.USE_REENTRANT = False
+    # sdxl_original_unet.USE_REENTRANT = False

     # test shape etc
     print("create unet")
     unet = sdxl_original_unet.SdxlUNet2DConditionModel()
     unet.to("cuda").to(torch.float16)

-    print("create LoRA controlnet")
-    control_net = LoRAControlNet(unet, 64, 32, 1)
+    print("create ControlNet-LLLite")
+    control_net = ControlNetLLLite(unet, 32, 64)
     control_net.apply_to()
     control_net.to("cuda")
@@ -414,6 +407,7 @@ if __name__ == "__main__":
print("start training") print("start training")
steps = 10 steps = 10
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
for step in range(steps): for step in range(steps):
print(f"step {step}") print(f"step {step}")
@@ -425,8 +419,7 @@ if __name__ == "__main__":
         y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()

         with torch.cuda.amp.autocast(enabled=True):
-            cond_embs_4d, cond_embs_3d = control_net(conditioning_image)
-            control_net.set_cond_embs(cond_embs_4d, cond_embs_3d)
+            control_net.set_cond_image(conditioning_image)
             output = unet(x, t, ctx, y)
             target = torch.randn_like(output)
@@ -436,3 +429,8 @@ if __name__ == "__main__":
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad(set_to_none=True)
+
+    print(sample_param)
+
+    # from safetensors.torch import save_file
+    # save_file(control_net.state_dict(), "logs/control_net.safetensors")
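
For orientation after the rename, here is a minimal usage sketch of the new classes, modeled on the `__main__` debug block above. It is not part of the commit; the tensor shapes, the 1024px control image, and the autocast/no_grad wrapping are assumptions.

```python
# Minimal sketch (assumptions noted above): exercise ControlNetLLLite the same way the
# __main__ debug code does, but for a single inference-style forward pass.
import torch

from library import sdxl_original_unet
from networks.control_net_lllite import ControlNetLLLite

unet = sdxl_original_unet.SdxlUNet2DConditionModel()
unet.to("cuda").to(torch.float16)

control_net = ControlNetLLLite(unet, cond_emb_dim=32, mlp_dim=64)
control_net.apply_to()  # hooks LLLiteModule.forward in front of each targeted U-Net sub-module
control_net.to("cuda")

# assumed shapes: 1024x1024 control image -> 128x128 SDXL latent
cond_image = torch.rand(1, 3, 1024, 1024, device="cuda", dtype=torch.float16)
latents = torch.randn(1, 4, 128, 128, device="cuda", dtype=torch.float16)
timestep = torch.randint(0, 1000, (1,), device="cuda")
context = torch.randn(1, 77, 2048, device="cuda", dtype=torch.float16)
y = torch.randn(1, sdxl_original_unet.ADM_IN_CHANNELS, device="cuda", dtype=torch.float16)

with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
    control_net.set_cond_image(cond_image)  # conditioning1 runs once here, not per timestep
    noise_pred = unet(latents, timestep, context, y)  # hooked modules mix cond_emb into their inputs

print(noise_pred.shape)
```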

View File

@@ -47,10 +47,9 @@ import library.train_util as train_util
 import library.sdxl_model_util as sdxl_model_util
 import library.sdxl_train_util as sdxl_train_util
 from networks.lora import LoRANetwork
-import tools.original_control_net as original_control_net
-from tools.original_control_net import ControlNetInfo
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
 from library.original_unet import FlashAttentionFunction
+from networks.control_net_lllite import ControlNetLLLite

 # scheduler:
 SCHEDULER_LINEAR_START = 0.00085
@@ -327,7 +326,7 @@ class PipelineLike:
             self.token_replacements_list.append({})

         # ControlNet  # not supported yet
-        self.control_nets: List[ControlNetInfo] = []
+        self.control_nets: List[ControlNetLLLite] = []
         self.control_net_enabled = True  # control_netsが空ならTrueでもFalseでもControlNetは動作しない

         # Textual Inversion
@@ -392,6 +391,7 @@ class PipelineLike:
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         img2img_noise=None,
+        clip_guide_images=None,
         **kwargs,
     ):
         # TODO support secondary prompt
@@ -496,11 +496,16 @@ class PipelineLike:
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])

         if self.control_nets:
+            # ControlNetのhintにguide imageを流用する
             if isinstance(clip_guide_images, PIL.Image.Image):
                 clip_guide_images = [clip_guide_images]
-            # ControlNetのhintにguide imageを流用する
-            # 前処理はControlNet側で行う
+            if isinstance(clip_guide_images[0], PIL.Image.Image):
+                clip_guide_images = [preprocess_image(im) for im in clip_guide_images]
+                clip_guide_images = torch.cat(clip_guide_images)
+            if isinstance(clip_guide_images, list):
+                clip_guide_images = torch.stack(clip_guide_images)
+            clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype)

         # create size embs
         if original_height is None:
@@ -654,35 +659,47 @@ class PipelineLike:
         num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1

         if self.control_nets:
-            guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+            # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+            if self.control_net_enabled:
+                for control_net in self.control_nets:
+                    with torch.no_grad():
+                        control_net.set_cond_image(clip_guide_images)
+            else:
+                for control_net in self.control_nets:
+                    control_net.set_cond_image(None)
         for i, t in enumerate(tqdm(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-            # predict the noise residual
-            if self.control_nets and self.control_net_enabled:
-                if reginonal_network:
-                    num_sub_and_neg_prompts = len(text_embeddings) // batch_size
-                    text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
-                else:
-                    text_emb_last = text_embeddings
-
-                # not working yet
-                noise_pred = original_control_net.call_unet_and_control_net(
-                    i,
-                    num_latent_input,
-                    self.unet,
-                    self.control_nets,
-                    guided_hints,
-                    i / len(timesteps),
-                    latent_model_input,
-                    t,
-                    text_emb_last,
-                ).sample
-            else:
-                noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
+            # # disable control net if ratio is set
+            # if self.control_nets and self.control_net_enabled:
+            #     pass  # TODO
+
+            # predict the noise residual
+            # TODO Diffusers' ControlNet
+            # if self.control_nets and self.control_net_enabled:
+            #     if reginonal_network:
+            #         num_sub_and_neg_prompts = len(text_embeddings) // batch_size
+            #         text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
+            #     else:
+            #         text_emb_last = text_embeddings
+
+            #     # not working yet
+            #     noise_pred = original_control_net.call_unet_and_control_net(
+            #         i,
+            #         num_latent_input,
+            #         self.unet,
+            #         self.control_nets,
+            #         guided_hints,
+            #         i / len(timesteps),
+            #         latent_model_input,
+            #         t,
+            #         text_emb_last,
+            #     ).sample
+            # else:
+            noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)

             # perform guidance
             if do_classifier_free_guidance:
@@ -1550,16 +1567,40 @@ def main(args):
             upscaler.to(dtype).to(device)

     # ControlNetの処理
-    control_nets: List[ControlNetInfo] = []
-    if args.control_net_models:
-        for i, model in enumerate(args.control_net_models):
-            prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
-            weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
-            ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
-
-            ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model)
-            prep = original_control_net.load_preprocess(prep_type)
-            control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+    control_nets: List[ControlNetLLLite] = []
+    # if args.control_net_models:
+    #     for i, model in enumerate(args.control_net_models):
+    #         prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
+    #         weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
+    #         ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+
+    #         ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model)
+    #         prep = original_control_net.load_preprocess(prep_type)
+    #         control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+    if args.control_net_lllite_models:
+        for i, model_file in enumerate(args.control_net_lllite_models):
+            print(f"loading ControlNet-LLLite: {model_file}")
+
+            from safetensors.torch import load_file
+
+            state_dict = load_file(model_file)
+            mlp_dim = None
+            cond_emb_dim = None
+            for key, value in state_dict.items():
+                if mlp_dim is None and "down.0.weight" in key:
+                    mlp_dim = value.shape[0]
+                elif cond_emb_dim is None and "conditioning1.0" in key:
+                    cond_emb_dim = value.shape[0] * 2
+                if mlp_dim is not None and cond_emb_dim is not None:
+                    break
+            assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}"
+
+            control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim)
+            control_net.apply_to()
+            control_net.load_state_dict(state_dict)
+            control_net.to(dtype).to(device)
+            control_net.set_batch_cond_only(False, False)
+            control_nets.append(control_net)

     if args.opt_channels_last:
         print(f"set optimizing: channels last")
@@ -1572,8 +1613,9 @@ def main(args):
             network.to(memory_format=torch.channels_last)

         for cn in control_nets:
-            cn.unet.to(memory_format=torch.channels_last)
-            cn.net.to(memory_format=torch.channels_last)
+            cn.to(memory_format=torch.channels_last)
+            # cn.unet.to(memory_format=torch.channels_last)
+            # cn.net.to(memory_format=torch.channels_last)

     pipe = PipelineLike(
         device,
@@ -2573,20 +2615,23 @@ def setup_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
-        "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
+        "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
     )
-    parser.add_argument(
-        "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
-    )
-    parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み")
-    parser.add_argument(
-        "--control_net_ratios",
-        type=float,
-        default=None,
-        nargs="*",
-        help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
-    )
+    # parser.add_argument(
+    #     "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
+    # )
+    # parser.add_argument(
+    #     "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
+    # )
+    # parser.add_argument("--control_net_multiplier", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率")
+    # parser.add_argument(
+    #     "--control_net_ratios",
+    #     type=float,
+    #     default=None,
+    #     nargs="*",
+    #     help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
+    # )
     # parser.add_argument(
     #     "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
     # )
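
One non-obvious part of the new loader above is how it recovers the construction hyperparameters from a saved file: `down.0.weight` has `mlp_dim` output features, and the first `conditioning1` convolution outputs `cond_emb_dim // 2` channels, hence the `* 2`. A self-contained sketch of that logic (the file name is a placeholder, and the surrounding setup is an assumption):

```python
# Sketch: infer ControlNet-LLLite sizes from a saved state dict, then attach it to a U-Net.
# "control_net_lllite.safetensors" is a placeholder path, not a file from this commit.
from safetensors.torch import load_file

from library import sdxl_original_unet
from networks.control_net_lllite import ControlNetLLLite

unet = sdxl_original_unet.SdxlUNet2DConditionModel()

state_dict = load_file("control_net_lllite.safetensors")

mlp_dim = None
cond_emb_dim = None
for key, value in state_dict.items():
    if mlp_dim is None and "down.0.weight" in key:
        mlp_dim = value.shape[0]  # output size of the first "down" layer
    elif cond_emb_dim is None and "conditioning1.0" in key:
        cond_emb_dim = value.shape[0] * 2  # first conv outputs cond_emb_dim // 2 channels
    if mlp_dim is not None and cond_emb_dim is not None:
        break
assert mlp_dim is not None and cond_emb_dim is not None

control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim)
control_net.apply_to()  # register the LLLite sub-modules first so the names match the state dict
control_net.load_state_dict(state_dict)
control_net.set_batch_cond_only(False, False)
```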

File diff suppressed because it is too large

View File

@@ -34,7 +34,7 @@ from library.custom_train_functions import (
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
 )
-import networks.lora_control_net as lora_control_net
+import networks.control_net_lllite as control_net_lllite

 # TODO 他のスクリプトと共通化する
@@ -176,7 +176,7 @@ def train(args):
         accelerator.wait_for_everyone()

     # prepare ControlNet
-    network = lora_control_net.LoRAControlNet(unet, args.cond_emb_dim, args.network_dim, 1, args.network_dropout)
+    network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
     network.apply_to()
     if args.network_weights is not None:
@@ -242,7 +242,7 @@ def train(args):
     unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, network, optimizer, train_dataloader, lr_scheduler
     )
-    network: lora_control_net.LoRAControlNet
+    network: control_net_lllite.ControlNetLLLite

     # transform DDP after prepare (train_network here only)
     unet, network = train_util.transform_models_if_DDP([unet, network])
@@ -311,7 +311,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lora_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )

     loss_list = []
@@ -401,11 +401,9 @@ def train(args):
                 controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)

                 with accelerator.autocast():
-                    # conditioning image embeddingを計算する / calculate conditioning image embedding
-                    cond_embs_4d, cond_embs_3d = network(controlnet_image)
-
-                    # 個別のLoRA的モジュールでさらにembeddingを計算する / calculate embedding in each LoRA-like module
-                    network.set_cond_embs(cond_embs_4d, cond_embs_3d)
+                    # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
+                    # 内部でcond_embに変換される / it will be converted to cond_emb inside
+                    network.set_cond_image(controlnet_image)

                     # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
                     noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
@@ -562,7 +560,7 @@ def setup_parser() -> argparse.ArgumentParser:
 if __name__ == "__main__":
-    sdxl_original_unet.USE_REENTRANT = False
+    # sdxl_original_unet.USE_REENTRANT = False

     parser = setup_parser()
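
The net effect on the training step is only the ordering shown in the last hunks: the conditioning image goes to the network once per step via `set_cond_image()`, and the U-Net is then called normally. A hedged fragment illustrating that flow (not runnable on its own: it assumes the training-loop variables of the script above such as `accelerator`, `network`, `unet`, `optimizer`, `lr_scheduler`, and the batch tensors, and it simplifies the loss handling):

```python
# Fragment: assumed context is the inner training loop of the script above.
with accelerator.autocast():
    # replaces the old two-step network(controlnet_image) / set_cond_embs(...) flow
    network.set_cond_image(controlnet_image)
    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
```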