mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add lora controlnet train/gen temporarily
This commit is contained in:
@@ -7,51 +7,87 @@ from library import sdxl_original_unet
|
||||
|
||||
SKIP_OUTPUT_BLOCKS = False
|
||||
SKIP_CONV2D = False
|
||||
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
|
||||
ATTN1_ETC_ONLY = True
|
||||
|
||||
|
||||
class LoRAModuleControlNet(LoRAModule):
|
||||
def __init__(self, depth, cond_emb_dim, name, org_module, multiplier, lora_dim, alpha, dropout=None):
|
||||
super().__init__(name, org_module, multiplier, lora_dim, alpha, dropout=dropout)
|
||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
self.cond_emb_dim = cond_emb_dim
|
||||
|
||||
# adjust channels of conditioning image to LoRA channels
|
||||
ch = 2 ** (depth - 1) * cond_emb_dim
|
||||
if self.is_conv2d:
|
||||
self.conditioning = torch.nn.Conv2d(ch, lora_dim, kernel_size=1, stride=1, padding=0)
|
||||
self.conditioning1 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.conditioning2 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(lora_dim + cond_emb_dim, cond_emb_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim, lora_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
else:
|
||||
self.conditioning = torch.nn.Linear(ch, lora_dim)
|
||||
torch.nn.init.zeros_(self.conditioning.weight) # zero conv/linear layer
|
||||
self.conditioning1 = torch.nn.Sequential(
|
||||
torch.nn.Linear(cond_emb_dim, cond_emb_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Linear(cond_emb_dim, cond_emb_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.conditioning2 = torch.nn.Sequential(
|
||||
torch.nn.Linear(lora_dim + cond_emb_dim, cond_emb_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Linear(cond_emb_dim, lora_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
# torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv
|
||||
|
||||
self.depth = depth
|
||||
self.cond_emb_dim = cond_emb_dim
|
||||
self.cond_emb = None
|
||||
self.batch_cond_uncond_enabled = False
|
||||
|
||||
def set_control(self, cond_emb):
|
||||
self.cond_emb = cond_emb
|
||||
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
|
||||
cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d
|
||||
cond_emb = cond_embs[self.depth - 1]
|
||||
self.cond_emb = self.conditioning1(cond_emb)
|
||||
|
||||
def set_batch_cond_uncond_enabled(self, enabled):
|
||||
self.batch_cond_uncond_enabled = enabled
|
||||
|
||||
def forward(self, x):
|
||||
# conditioning image embs -> LoRA channels
|
||||
cx = self.cond_emb
|
||||
if not self.is_conv2d:
|
||||
# b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = cx.shape
|
||||
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
||||
# print(f"C {self.lora_name}, x.shape={x.shape}, cx.shape={cx.shape}, weight.shape={self.conditioning.weight.shape}")
|
||||
cx = self.conditioning(cx)
|
||||
if self.cond_emb is None:
|
||||
return self.org_forward(x)
|
||||
|
||||
# LoRA
|
||||
# print(f"C {self.lora_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
||||
lx = self.lora_down(x)
|
||||
lx = x
|
||||
if self.batch_cond_uncond_enabled:
|
||||
lx = lx[1::2] # cond only
|
||||
|
||||
lx = self.lora_down(lx)
|
||||
|
||||
if self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# add conditioning
|
||||
lx = lx + cx
|
||||
# conditioning image
|
||||
cx = self.cond_emb
|
||||
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
|
||||
cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.conditioning2(cx)
|
||||
|
||||
lx = lx + cx
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
x = self.org_forward(x) + lx * self.multiplier * self.scale
|
||||
x = self.org_forward(x)
|
||||
|
||||
if self.batch_cond_uncond_enabled:
|
||||
x[1::2] += lx * self.multiplier * self.scale
|
||||
else:
|
||||
x += lx * self.multiplier * self.scale
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@@ -106,6 +142,16 @@ class LoRAControlNet(torch.nn.Module):
|
||||
if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)):
|
||||
continue
|
||||
|
||||
if ATTN1_ETC_ONLY:
|
||||
if "proj_out" in lora_name:
|
||||
pass
|
||||
elif "attn1" in lora_name and ("to_k" in lora_name or "to_v" in lora_name or "to_out" in lora_name):
|
||||
pass
|
||||
elif "ff_net_2" in lora_name:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
lora = module_class(
|
||||
depth,
|
||||
cond_emb_dim,
|
||||
@@ -119,52 +165,56 @@ class LoRAControlNet(torch.nn.Module):
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
if not TRANSFORMER_ONLY:
|
||||
target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
# create module instances
|
||||
self.unet_loras: List[LoRAModuleControlNet] = create_modules(unet, target_modules, LoRAModuleControlNet)
|
||||
print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
# stem for conditioning image
|
||||
self.cond_stem = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(3, cond_emb_dim, kernel_size=4, stride=4, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
# embs for each depth
|
||||
# conditioning image embedding
|
||||
self.cond_block0 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent size
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.cond_block1 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim * 2, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.cond_block2 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim * 2, cond_emb_dim * 4, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.cond_block3 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim * 4, cond_emb_dim * 8, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
# forawrdでなくset_controlに入れてもやはり動かない
|
||||
def forward(self, x):
|
||||
cx = self.cond_stem(x)
|
||||
cx = self.cond_block0(cx)
|
||||
c0 = cx
|
||||
cx = self.cond_block1(cx)
|
||||
c1 = cx
|
||||
cx = self.cond_block2(cx)
|
||||
c2 = cx
|
||||
cx = self.cond_block3(cx)
|
||||
c3 = cx
|
||||
return c0, c1, c2, c3
|
||||
x = self.cond_block0(x)
|
||||
x0 = x
|
||||
x = self.cond_block1(x)
|
||||
x1 = x
|
||||
x = self.cond_block2(x)
|
||||
x2 = x
|
||||
|
||||
def set_control(self, cond_embs):
|
||||
x_3d = []
|
||||
for x0 in [x0, x1, x2]:
|
||||
# b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = x0.shape
|
||||
x0 = x0.view(n, c, h * w).permute(0, 2, 1)
|
||||
x_3d.append(x0)
|
||||
|
||||
return [x0, x1, x2], x_3d
|
||||
|
||||
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
|
||||
for lora in self.unet_loras:
|
||||
lora.set_control(cond_embs[lora.depth - 1])
|
||||
lora.set_cond_embs(cond_embs_4d, cond_embs_3d)
|
||||
|
||||
def set_batch_cond_uncond_enabled(self, enabled):
|
||||
for lora in self.unet_loras:
|
||||
lora.set_batch_cond_uncond_enabled(enabled)
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
@@ -228,18 +278,20 @@ class LoRAControlNet(torch.nn.Module):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sdxl_original_unet.USE_REENTRANT = False
|
||||
|
||||
# test shape etc
|
||||
print("create unet")
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
unet.to("cuda") # , dtype=torch.float16)
|
||||
unet.to("cuda").to(torch.float16)
|
||||
|
||||
print("create LoRA controlnet")
|
||||
control_net = LoRAControlNet(unet, 16, 32, 1)
|
||||
control_net = LoRAControlNet(unet, 128, 32, 1)
|
||||
control_net.apply_to()
|
||||
control_net.to("cuda")
|
||||
|
||||
# print(controlnet)
|
||||
# input()
|
||||
print(control_net)
|
||||
input()
|
||||
|
||||
# print number of parameters
|
||||
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
|
||||
@@ -282,8 +334,9 @@ if __name__ == "__main__":
|
||||
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
cond_embs = control_net(conditioning_image)
|
||||
control_net.set_control(cond_embs)
|
||||
cond_embs_4d, cond_embs_3d = control_net(conditioning_image)
|
||||
control_net.set_cond_embs(cond_embs_4d, cond_embs_3d)
|
||||
|
||||
output = unet(x, t, ctx, y)
|
||||
target = torch.randn_like(output)
|
||||
loss = torch.nn.functional.mse_loss(output, target)
|
||||
|
||||
Reference in New Issue
Block a user