From 6d5cffaee97cd32e1de4f8ec4129102d2280e42d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 22 Aug 2023 08:17:21 +0900 Subject: [PATCH] add multiplier, steps range --- networks/control_net_lllite.py | 24 ++++++++++++++--- sdxl_gen_img.py | 48 ++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index 3140919c..4ebfef7a 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -33,7 +33,7 @@ TRANSFORMER_MAX_BLOCK_INDEX = None class LLLiteModule(torch.nn.Module): - def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None): + def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0): super().__init__() self.is_conv2d = org_module.__class__.__name__ == "Conv2d" @@ -41,6 +41,7 @@ class LLLiteModule(torch.nn.Module): self.cond_emb_dim = cond_emb_dim self.org_module = [org_module] self.dropout = dropout + self.multiplier = multiplier if self.is_conv2d: in_dim = org_module.in_channels @@ -119,6 +120,10 @@ class LLLiteModule(torch.nn.Module): 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む / call the model inside, so if necessary, surround it with torch.no_grad() """ + if cond_image is None: + self.cond_emb = None + return + # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") cx = self.conditioning1(cond_image) @@ -141,6 +146,9 @@ class LLLiteModule(torch.nn.Module): 学習用の便利forward。元のモジュールのforwardを呼び出す / convenient forward for training. call the forward of the original module """ + if self.multiplier == 0.0 or self.cond_emb is None: + return self.org_forward(x) + cx = self.cond_emb if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only @@ -160,11 +168,13 @@ class LLLiteModule(torch.nn.Module): if self.dropout is not None and self.training: cx = torch.nn.functional.dropout(cx, p=self.dropout) - cx = self.up(cx) + cx = self.up(cx) * self.multiplier - # residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward + # residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward if self.batch_cond_only: - cx = torch.zeros_like(x)[1::2] + cx + zx = torch.zeros_like(x) + zx[1::2] += cx + cx = zx x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here return x @@ -181,6 +191,7 @@ class ControlNetLLLite(torch.nn.Module): mlp_dim: int = 16, dropout: Optional[float] = None, varbose: Optional[bool] = False, + multiplier: Optional[float] = 1.0, ) -> None: super().__init__() # self.unets = [unet] @@ -264,6 +275,7 @@ class ControlNetLLLite(torch.nn.Module): child_module, mlp_dim, dropout=dropout, + multiplier=multiplier, ) modules.append(module) return modules @@ -291,6 +303,10 @@ class ControlNetLLLite(torch.nn.Module): for module in self.unet_modules: module.set_batch_cond_only(cond_only, zeros) + def set_multiplier(self, multiplier): + for module in self.unet_modules: + module.multiplier = multiplier + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index cb16a781..0152578b 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -661,21 +661,28 @@ class PipelineLike: if self.control_nets: # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) if self.control_net_enabled: - for control_net in self.control_nets: + for control_net, _ in self.control_nets: with torch.no_grad(): control_net.set_cond_image(clip_guide_images) else: - for control_net in self.control_nets: + for control_net, _ in self.control_nets: control_net.set_cond_image(None) + each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # # disable control net if ratio is set - # if self.control_nets and self.control_net_enabled: - # pass # TODO + # disable control net if ratio is set + if self.control_nets and self.control_net_enabled: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + control_net.set_cond_image(None) + each_control_net_enabled[j] = False # predict the noise residual # TODO Diffusers' ControlNet @@ -1567,7 +1574,7 @@ def main(args): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetLLLite] = [] + control_nets: List[Tuple[ControlNetLLLite, float]] = [] # if args.control_net_models: # for i, model in enumerate(args.control_net_models): # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] @@ -1595,12 +1602,19 @@ def main(args): break assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" - control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim) + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) control_net.apply_to() control_net.load_state_dict(state_dict) control_net.to(dtype).to(device) control_net.set_batch_cond_only(False, False) - control_nets.append(control_net) + control_nets.append((control_net, ratio)) if args.opt_channels_last: print(f"set optimizing: channels last") @@ -2623,14 +2637,16 @@ def setup_parser() -> argparse.ArgumentParser: # parser.add_argument( # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" # ) - # parser.add_argument("--control_net_multiplier", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率") - # parser.add_argument( - # "--control_net_ratios", - # type=float, - # default=None, - # nargs="*", - # help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", - # ) + parser.add_argument( + "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" + ) + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # )