Files
Kohya-ss-sd-scripts/networks/lora_control_net.py
2023-08-17 10:19:14 +09:00

348 lines
12 KiB
Python

import os
from typing import Optional, List, Type
import torch
from networks.lora import LoRAModule, LoRANetwork
from library import sdxl_original_unet
SKIP_OUTPUT_BLOCKS = True
SKIP_CONV2D = False
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
ATTN1_ETC_ONLY = True
class LoRAModuleControlNet(LoRAModule):
def __init__(self, depth, cond_emb_dim, name, org_module, multiplier, lora_dim, alpha, dropout=None):
super().__init__(name, org_module, multiplier, lora_dim, alpha, dropout=dropout)
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
self.cond_emb_dim = cond_emb_dim
if self.is_conv2d:
self.conditioning1 = torch.nn.Sequential(
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.conditioning2 = torch.nn.Sequential(
torch.nn.Conv2d(lora_dim + cond_emb_dim, cond_emb_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(cond_emb_dim, lora_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
else:
self.conditioning1 = torch.nn.Sequential(
torch.nn.Linear(cond_emb_dim, cond_emb_dim),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(cond_emb_dim, cond_emb_dim),
torch.nn.ReLU(inplace=True),
)
self.conditioning2 = torch.nn.Sequential(
torch.nn.Linear(lora_dim + cond_emb_dim, cond_emb_dim),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(cond_emb_dim, lora_dim),
torch.nn.ReLU(inplace=True),
)
# torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv
self.depth = depth
self.cond_emb = None
self.batch_cond_uncond_enabled = False
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d
cond_emb = cond_embs[self.depth - 1]
self.cond_emb = self.conditioning1(cond_emb)
def set_batch_cond_uncond_enabled(self, enabled):
self.batch_cond_uncond_enabled = enabled
def forward(self, x):
if self.cond_emb is None:
return self.org_forward(x)
# LoRA
lx = x
if self.batch_cond_uncond_enabled:
lx = lx[1::2] # cond only
lx = self.lora_down(lx)
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# conditioning image
cx = self.cond_emb
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
cx = self.conditioning2(cx)
lx = lx + cx
lx = self.lora_up(lx)
x = self.org_forward(x)
if self.batch_cond_uncond_enabled:
x[1::2] += lx * self.multiplier * self.scale
else:
x += lx * self.multiplier * self.scale
return x
class LoRAControlNet(torch.nn.Module):
def __init__(
self,
unet: sdxl_original_unet.SdxlUNet2DConditionModel,
cond_emb_dim: int = 16,
lora_dim: int = 16,
alpha: float = 1,
dropout: Optional[float] = None,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
# self.unets = [unet]
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
module_class: Type[object],
) -> List[torch.nn.Module]:
prefix = LoRANetwork.LORA_PREFIX_UNET
loras = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
if is_linear or (is_conv2d and not SKIP_CONV2D):
# block index to depth: depth is using to calculate conditioning size and channels
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
index1 = int(index1)
if block_name == "input_blocks":
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
elif block_name == "middle_block":
depth = 3
elif block_name == "output_blocks":
if SKIP_OUTPUT_BLOCKS:
continue
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
if int(index2) >= 2:
depth -= 1
else:
raise NotImplementedError()
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
# skip time emb or clip emb
if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)):
continue
if ATTN1_ETC_ONLY:
if "proj_out" in lora_name:
pass
elif "attn1" in lora_name and ("to_k" in lora_name or "to_v" in lora_name or "to_out" in lora_name):
pass
elif "ff_net_2" in lora_name:
pass
else:
continue
lora = module_class(
depth,
cond_emb_dim,
lora_name,
child_module,
1.0,
lora_dim,
alpha,
dropout=dropout,
)
loras.append(lora)
return loras
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
if not TRANSFORMER_ONLY:
target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
# create module instances
self.unet_loras: List[LoRAModuleControlNet] = create_modules(unet, target_modules, LoRAModuleControlNet)
print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.")
# conditioning image embedding
self.cond_block0 = torch.nn.Sequential(
torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent size
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1),
torch.nn.ReLU(inplace=True),
)
self.cond_block1 = torch.nn.Sequential(
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
torch.nn.ReLU(inplace=True),
)
self.cond_block2 = torch.nn.Sequential(
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
torch.nn.ReLU(inplace=True),
)
def forward(self, x):
x = self.cond_block0(x)
x0 = x
x = self.cond_block1(x)
x1 = x
x = self.cond_block2(x)
x2 = x
x_3d = []
for x0 in [x0, x1, x2]:
# b,c,h,w -> b,h*w,c
n, c, h, w = x0.shape
x0 = x0.view(n, c, h * w).permute(0, 2, 1)
x_3d.append(x0)
return [x0, x1, x2], x_3d
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
for lora in self.unet_loras:
lora.set_cond_embs(cond_embs_4d, cond_embs_3d)
def set_batch_cond_uncond_enabled(self, enabled):
for lora in self.unet_loras:
lora.set_batch_cond_uncond_enabled(enabled)
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self):
print("applying LoRA for U-Net...")
for lora in self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
# マージできるかどうかを返す
def is_mergeable(self):
return False
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
raise NotImplementedError()
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_optimizer_params(self):
self.requires_grad_(True)
return self.parameters()
def prepare_grad_etc(self):
self.requires_grad_(True)
def on_epoch_start(self):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
if __name__ == "__main__":
sdxl_original_unet.USE_REENTRANT = False
# test shape etc
print("create unet")
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
unet.to("cuda").to(torch.float16)
print("create LoRA controlnet")
control_net = LoRAControlNet(unet, 256, 64, 1)
control_net.apply_to()
control_net.to("cuda")
print(control_net)
input()
# print number of parameters
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
unet.set_use_memory_efficient_attention(True, False)
unet.set_gradient_checkpointing(True)
unet.train() # for gradient checkpointing
control_net.train()
# # visualize
# import torchviz
# print("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# image.format = "svg" # "png"
# image.render("NeuralNet")
# input()
import bitsandbytes
optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
steps = 10
for step in range(steps):
print(f"step {step}")
batch_size = 1
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
x = torch.randn(batch_size, 4, 128, 128).cuda()
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
ctx = torch.randn(batch_size, 77, 2048).cuda()
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
with torch.cuda.amp.autocast(enabled=True):
cond_embs_4d, cond_embs_3d = control_net(conditioning_image)
control_net.set_cond_embs(cond_embs_4d, cond_embs_3d)
output = unet(x, t, ctx, y)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)