mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
add cond/uncond, update config
This commit is contained in:
@@ -5,11 +5,12 @@ from networks.lora import LoRAModule, LoRANetwork
|
||||
from library import sdxl_original_unet
|
||||
|
||||
|
||||
SKIP_INPUT_BLOCKS = True
|
||||
SKIP_OUTPUT_BLOCKS = False
|
||||
SKIP_INPUT_BLOCKS = False
|
||||
SKIP_OUTPUT_BLOCKS = True
|
||||
SKIP_CONV2D = False
|
||||
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
|
||||
ATTN1_ETC_ONLY = True
|
||||
TRANSFORMER_MAX_BLOCK_INDEX = 3 # None # 2 # None for all blocks
|
||||
|
||||
|
||||
class LoRAModuleControlNet(LoRAModule):
|
||||
@@ -48,15 +49,17 @@ class LoRAModuleControlNet(LoRAModule):
|
||||
|
||||
self.depth = depth
|
||||
self.cond_emb = None
|
||||
self.batch_cond_uncond_enabled = False
|
||||
self.batch_cond_only = False
|
||||
self.use_zeros_for_batch_uncond = False
|
||||
|
||||
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
|
||||
cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d
|
||||
cond_emb = cond_embs[self.depth - 1]
|
||||
self.cond_emb = self.conditioning1(cond_emb)
|
||||
|
||||
def set_batch_cond_uncond_enabled(self, enabled):
|
||||
self.batch_cond_uncond_enabled = enabled
|
||||
def set_batch_cond_only(self, cond_only, zeros):
|
||||
self.batch_cond_only = cond_only
|
||||
self.use_zeros_for_batch_uncond = zeros
|
||||
|
||||
def forward(self, x):
|
||||
if self.cond_emb is None:
|
||||
@@ -64,7 +67,7 @@ class LoRAModuleControlNet(LoRAModule):
|
||||
|
||||
# LoRA
|
||||
lx = x
|
||||
if self.batch_cond_uncond_enabled:
|
||||
if self.batch_cond_only:
|
||||
lx = lx[1::2] # cond only
|
||||
|
||||
lx = self.lora_down(lx)
|
||||
@@ -75,6 +78,10 @@ class LoRAModuleControlNet(LoRAModule):
|
||||
# conditioning image
|
||||
cx = self.cond_emb
|
||||
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
if not self.batch_cond_only and cx.shape[0] // 2 == lx.shape[0]: # inference only
|
||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||
if self.use_zeros_for_batch_uncond:
|
||||
cx[0::2] = 0.0 # uncond is zero
|
||||
|
||||
cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.conditioning2(cx)
|
||||
@@ -84,7 +91,7 @@ class LoRAModuleControlNet(LoRAModule):
|
||||
|
||||
x = self.org_forward(x)
|
||||
|
||||
if self.batch_cond_uncond_enabled:
|
||||
if self.batch_cond_only:
|
||||
x[1::2] += lx * self.multiplier * self.scale
|
||||
else:
|
||||
x += lx * self.multiplier * self.scale
|
||||
@@ -141,6 +148,13 @@ class LoRAControlNet(torch.nn.Module):
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
||||
p = lora_name.find("transformer_blocks")
|
||||
if p >= 0:
|
||||
tf_index = int(lora_name[p:].split("_")[2])
|
||||
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
||||
continue
|
||||
|
||||
# skip time emb or clip emb
|
||||
if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)):
|
||||
continue
|
||||
@@ -215,9 +229,9 @@ class LoRAControlNet(torch.nn.Module):
|
||||
for lora in self.unet_loras:
|
||||
lora.set_cond_embs(cond_embs_4d, cond_embs_3d)
|
||||
|
||||
def set_batch_cond_uncond_enabled(self, enabled):
|
||||
def set_batch_cond_only(self, cond_only, zeros):
|
||||
for lora in self.unet_loras:
|
||||
lora.set_batch_cond_uncond_enabled(enabled)
|
||||
lora.set_batch_cond_only(cond_only, zeros)
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
@@ -294,11 +308,12 @@ if __name__ == "__main__":
|
||||
control_net.to("cuda")
|
||||
|
||||
print(control_net)
|
||||
input()
|
||||
|
||||
# print number of parameters
|
||||
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
|
||||
|
||||
input()
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
unet.set_gradient_checkpointing(True)
|
||||
unet.train() # for gradient checkpointing
|
||||
|
||||
@@ -1571,7 +1571,7 @@ def main(args):
|
||||
control_net.apply_to()
|
||||
control_net.load_state_dict(state_dict)
|
||||
control_net.to(dtype).to(device)
|
||||
control_net.set_batch_cond_uncond_enabled(True)
|
||||
control_net.set_batch_cond_only(False, False)
|
||||
control_nets.append(control_net)
|
||||
|
||||
if args.opt_channels_last:
|
||||
|
||||
Reference in New Issue
Block a user